sanitize_engine/
allowlist.rs1use regex::Regex;
39use std::collections::HashSet;
40use std::sync::atomic::{AtomicU64, Ordering};
41
42pub struct AllowlistMatcher {
59 exact: HashSet<String>,
60 globs: Vec<String>,
61 regexes: Vec<(String, Regex)>,
63 case_sensitive: bool,
66 seen: AtomicU64,
68}
69
70impl AllowlistMatcher {
71 #[must_use]
82 pub fn new(patterns: Vec<String>) -> (Self, Vec<String>) {
83 Self::build(patterns, false)
84 }
85
86 #[must_use]
91 pub fn new_case_sensitive(patterns: Vec<String>) -> (Self, Vec<String>) {
92 Self::build(patterns, true)
93 }
94
95 fn build(patterns: Vec<String>, case_sensitive: bool) -> (Self, Vec<String>) {
96 let mut exact = HashSet::new();
97 let mut globs = Vec::new();
98 let mut regexes = Vec::new();
99 let mut warnings = Vec::new();
100
101 for pat in patterns {
102 if let Some(re_src) = pat.strip_prefix("regex:") {
103 match Regex::new(re_src) {
104 Ok(compiled) => regexes.push((pat, compiled)),
105 Err(e) => warnings.push(format!(
106 "allowlist pattern '{pat}' failed to compile: {e} — pattern skipped"
107 )),
108 }
109 continue;
110 }
111
112 for ch in ['^', '$', '+', '(', ')'] {
113 if ch == '$' && !pat.replace("${", "").contains('$') {
116 continue;
117 }
118 if pat.contains(ch) {
119 warnings.push(format!(
120 "allowlist pattern '{pat}' contains regex character '{ch}'; \
121 it is matched literally — use the 'regex:' prefix for regex syntax"
122 ));
123 break;
124 }
125 }
126 let stored = if case_sensitive {
129 pat
130 } else {
131 pat.to_lowercase()
132 };
133 if stored.contains('*') {
134 globs.push(stored);
135 } else {
136 exact.insert(stored);
137 }
138 }
139
140 (
141 Self {
142 exact,
143 globs,
144 regexes,
145 case_sensitive,
146 seen: AtomicU64::new(0),
147 },
148 warnings,
149 )
150 }
151
152 pub fn is_allowed(&self, value: &str) -> bool {
156 self.match_pattern(value).is_some()
157 }
158
159 pub fn match_pattern<'a>(&'a self, value: &str) -> Option<&'a str> {
171 let normalized: std::borrow::Cow<str> = if self.case_sensitive {
173 std::borrow::Cow::Borrowed(value)
174 } else {
175 std::borrow::Cow::Owned(value.to_lowercase())
176 };
177 if let Some(s) = self.exact.get(normalized.as_ref()) {
178 self.seen.fetch_add(1, Ordering::Relaxed);
179 return Some(s.as_str());
180 }
181 for pat in &self.globs {
182 if glob_matches(pat, &normalized) {
183 self.seen.fetch_add(1, Ordering::Relaxed);
184 return Some(pat.as_str());
185 }
186 }
187 for (pat_str, re) in &self.regexes {
190 if re.is_match(value) {
191 self.seen.fetch_add(1, Ordering::Relaxed);
192 return Some(pat_str.as_str());
193 }
194 }
195 None
196 }
197
198 pub fn seen_count(&self) -> u64 {
200 self.seen.load(Ordering::Relaxed)
201 }
202
203 pub fn pattern_count(&self) -> usize {
205 self.exact.len() + self.globs.len() + self.regexes.len()
206 }
207
208 pub fn is_empty(&self) -> bool {
210 self.exact.is_empty() && self.globs.is_empty() && self.regexes.is_empty()
211 }
212}
213
/// Returns `true` when `value` matches the glob `pattern`.
///
/// `*` is the only wildcard and matches any run of characters, including an
/// empty one; consecutive `*` behave like a single `*`. Every other character
/// is compared literally. No case folding happens here, so callers that want
/// case-insensitive matching must normalise both arguments first.
///
/// A pattern that contains no `*` matches only a `value` equal to it.
///
/// The literal text before the first `*` and after the last `*` must occupy
/// disjoint regions of `value`: `"a*b*a"` does not match `"a"`.
pub(crate) fn glob_matches(pattern: &str, value: &str) -> bool {
    let mut parts = pattern.split('*');

    // `split` always yields at least one piece: the text before the first `*`.
    let prefix = parts.next().unwrap_or("");
    let suffix = match parts.next_back() {
        Some(tail) => tail,
        // No `*` at all, so the pattern is a plain literal.
        None => return value == pattern,
    };

    // Peel the prefix first and look for the suffix only in what is left, so
    // the two can never overlap inside `value`.
    let after_prefix = match value.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    let mut window = match after_prefix.strip_suffix(suffix) {
        Some(rest) => rest,
        None => return false,
    };

    // Each remaining literal must appear in order inside the window. Taking
    // the leftmost occurrence leaves the most room for the literals after it.
    for part in parts {
        if part.is_empty() {
            continue;
        }
        match window.find(part) {
            Some(at) => window = &window[at + part.len()..],
            None => return false,
        }
    }
    true
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Case-insensitive matcher over `pats`; build warnings are discarded.
    fn matcher(pats: &[&str]) -> AllowlistMatcher {
        let owned: Vec<String> = pats.iter().map(|p| String::from(*p)).collect();
        AllowlistMatcher::new(owned).0
    }

    /// Case-sensitive matcher over `pats`; build warnings are discarded.
    fn matcher_cs(pats: &[&str]) -> AllowlistMatcher {
        let owned: Vec<String> = pats.iter().map(|p| String::from(*p)).collect();
        AllowlistMatcher::new_case_sensitive(owned).0
    }

    /// Asserts that every value is allowed (one `is_allowed` call per value).
    fn assert_allowed(m: &AllowlistMatcher, values: &[&str]) {
        for v in values {
            assert!(m.is_allowed(v));
        }
    }

    /// Asserts that every value is rejected (one `is_allowed` call per value).
    fn assert_denied(m: &AllowlistMatcher, values: &[&str]) {
        for v in values {
            assert!(!m.is_allowed(v));
        }
    }

    // ---- exact patterns -------------------------------------------------

    #[test]
    fn exact_match() {
        let m = matcher(&["localhost", "127.0.0.1"]);
        assert_allowed(&m, &["localhost", "127.0.0.1", "Localhost", "LOCALHOST"]);
        assert_denied(&m, &["localhost2"]);
    }

    #[test]
    fn exact_match_case_sensitive() {
        let m = matcher_cs(&["localhost", "127.0.0.1"]);
        assert_allowed(&m, &["localhost"]);
        assert_denied(&m, &["Localhost", "LOCALHOST"]);
    }

    // ---- glob patterns --------------------------------------------------

    #[test]
    fn glob_suffix() {
        let m = matcher(&["*.internal"]);
        assert_allowed(&m, &["db.internal", "staging.db.internal"]);
        assert_denied(&m, &["db.internal.evil", "internal"]);
    }

    #[test]
    fn glob_prefix() {
        let m = matcher(&["192.168.1.*"]);
        assert_allowed(&m, &["192.168.1.1", "192.168.1.255"]);
        assert_denied(&m, &["192.168.2.1"]);
        // `*` may match the empty string.
        assert_allowed(&m, &["192.168.1."]);
    }

    #[test]
    fn glob_middle() {
        let m = matcher(&["user-*@corp.com"]);
        assert_allowed(&m, &["user-alice@corp.com", "user-bob@corp.com"]);
        assert_denied(&m, &["admin@corp.com", "user-alice@other.com"]);
    }

    #[test]
    fn glob_star_only() {
        let m = matcher(&["*"]);
        assert_allowed(&m, &["anything", ""]);
    }

    // ---- counters, warnings, emptiness ----------------------------------

    #[test]
    fn seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        for v in ["ok", "ok", "not-ok"] {
            m.is_allowed(v);
        }
        assert_eq!(m.seen_count(), 2);
    }

    #[test]
    fn regex_char_warning() {
        let (_, warnings) = AllowlistMatcher::new(vec![String::from("^bad$")]);
        assert!(!warnings.is_empty());
    }

    #[test]
    fn empty_allowlist_is_empty() {
        let m = matcher(&[]);
        assert!(m.is_empty());
        assert_denied(&m, &["anything"]);
    }

    // ---- match_pattern --------------------------------------------------

    #[test]
    fn match_pattern_returns_exact_pattern() {
        let m = matcher(&["localhost"]);
        assert_eq!(m.match_pattern("localhost"), Some("localhost"));
        assert_eq!(m.match_pattern("other"), None);
    }

    #[test]
    fn match_pattern_returns_glob_pattern() {
        let m = matcher(&["*.internal"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
        assert_eq!(m.match_pattern("github.com"), None);
    }

    #[test]
    fn match_pattern_returns_first_matching_pattern() {
        let m = matcher(&["*.internal", "db.*"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
    }

    #[test]
    fn match_pattern_increments_seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        m.match_pattern("ok");
        assert_eq!(m.seen_count(), 1);
        m.match_pattern("not-ok");
        assert_eq!(m.seen_count(), 1);
    }

    #[test]
    fn is_allowed_delegates_to_match_pattern() {
        let m = matcher(&["*.internal"]);
        assert_allowed(&m, &["db.internal"]);
        assert_denied(&m, &["github.com"]);
        assert_eq!(m.seen_count(), 1);
    }

    // ---- glob edge cases ------------------------------------------------

    #[test]
    fn glob_multiple_wildcards() {
        let m = matcher(&["a*b*c"]);
        assert_allowed(&m, &["abc", "aXbYc", "aXXXbYYYc"]);
        assert_denied(&m, &["abX", "Xbc"]);
    }

    #[test]
    fn glob_adjacent_wildcards_treated_as_one() {
        let m = matcher(&["a**b"]);
        assert_allowed(&m, &["ab", "aXb"]);
        assert_denied(&m, &["ba"]);
    }

    #[test]
    fn glob_empty_value_only_matches_star() {
        let star = matcher(&["*"]);
        assert_allowed(&star, &[""]);
        let prefixed = matcher(&["a*"]);
        assert_denied(&prefixed, &[""]);
    }

    #[test]
    fn glob_prefix_suffix_overlap_rejected() {
        let m = matcher(&["a*b"]);
        assert_denied(&m, &["a", "b"]);
        assert_allowed(&m, &["ab", "aXb"]);
    }

    // ---- mixed / bulk ---------------------------------------------------

    #[test]
    fn large_exact_list_all_match() {
        let words: Vec<String> = (0..500).map(|i| format!("word{i}")).collect();
        let (m, _) = AllowlistMatcher::new(words.clone());
        for w in &words {
            assert!(m.is_allowed(w), "should allow {w}");
        }
        assert_denied(&m, &["word500", "notaword"]);
    }

    #[test]
    fn exact_and_glob_coexist() {
        let m = matcher(&["localhost", "127.0.0.1", "*.internal"]);
        assert_allowed(&m, &["localhost", "127.0.0.1", "db.internal"]);
        assert_denied(&m, &["github.com"]);
    }

    // ---- regex patterns -------------------------------------------------

    #[test]
    fn regex_basic_match() {
        let m = matcher(&[r"regex:^192\.168\.[0-9]+\.[0-9]+$"]);
        assert_allowed(&m, &["192.168.1.1", "192.168.100.255"]);
        assert_denied(&m, &["192.168.1.", "10.0.0.1"]);
    }

    #[test]
    fn regex_substring_match_without_anchors() {
        let m = matcher(&["regex:internal"]);
        assert_allowed(&m, &["db.internal.corp", "internal"]);
        assert_denied(&m, &["external"]);
    }

    #[test]
    fn regex_anchored_full_match() {
        let m = matcher(&["regex:^token-[A-Z]{3}-[0-9]{4}$"]);
        assert_allowed(&m, &["token-ABC-1234"]);
        assert_denied(&m, &["token-AB-1234", "xtoken-ABC-1234"]);
    }

    #[test]
    fn regex_case_sensitive_by_default() {
        let m = matcher(&["regex:^localhost$"]);
        assert_allowed(&m, &["localhost"]);
        assert_denied(&m, &["LOCALHOST", "Localhost"]);
    }

    #[test]
    fn regex_case_insensitive_via_flag() {
        let m = matcher(&["regex:(?i)^localhost$"]);
        assert_allowed(&m, &["localhost", "LOCALHOST", "LocalHost"]);
    }

    #[test]
    fn regex_invalid_pattern_produces_warning_and_is_skipped() {
        let (m, warnings) = AllowlistMatcher::new(vec![String::from("regex:[invalid")]);
        assert!(!warnings.is_empty(), "invalid regex must produce a warning");
        assert!(warnings[0].contains("failed to compile"));
        assert_denied(&m, &["anything"]);
        assert_eq!(m.pattern_count(), 0);
    }

    #[test]
    fn regex_match_pattern_returns_full_prefixed_string() {
        let m = matcher(&[r"regex:^10\.0\."]);
        assert_eq!(m.match_pattern("10.0.1.5"), Some(r"regex:^10\.0\."));
        assert_eq!(m.match_pattern("192.168.1.1"), None);
    }

    #[test]
    fn regex_seen_counter_increments() {
        let m = matcher(&["regex:^test"]);
        assert_eq!(m.seen_count(), 0);
        for v in ["test-value", "test-value", "other"] {
            m.is_allowed(v);
        }
        assert_eq!(m.seen_count(), 2);
    }

    #[test]
    fn regex_coexists_with_exact_and_glob() {
        let m = matcher(&[
            "localhost",
            "*.internal",
            r"regex:^10\.[0-9]+\.[0-9]+\.[0-9]+$",
        ]);
        assert_allowed(
            &m,
            &["localhost", "db.internal", "10.0.0.1", "10.255.255.255"],
        );
        assert_denied(&m, &["192.168.1.1", "github.com"]);
        assert_eq!(m.pattern_count(), 3);
    }

    #[test]
    fn regex_not_subject_to_case_insensitive_lowercasing() {
        // The matcher is case-insensitive, yet the regex still sees the
        // value with its original casing.
        let m = matcher(&["regex:^[A-Z]{3}$"]);
        assert_allowed(&m, &["ABC"]);
        assert_denied(&m, &["abc"]);
    }

    #[test]
    fn metacharacter_warning_updated_to_suggest_regex_prefix() {
        let (_, warnings) = AllowlistMatcher::new(vec![String::from("^bad$")]);
        assert!(!warnings.is_empty());
        assert!(
            warnings[0].contains("regex:"),
            "warning should suggest regex: prefix, got: {}",
            warnings[0],
        );
    }
}