sanitize_engine/
allowlist.rs1use regex::Regex;
39use std::collections::HashSet;
40use std::sync::atomic::{AtomicU64, Ordering};
41
/// Allowlist of patterns that values are tested against.
///
/// Patterns are split into three groups at construction time: wildcard-free
/// literals, `*` globs, and `regex:`-prefixed regular expressions.
pub struct AllowlistMatcher {
    /// Literal patterns (no `*`); stored lowercased unless `case_sensitive` is set.
    exact: HashSet<String>,
    /// Patterns containing at least one `*`, kept in insertion order; stored
    /// lowercased unless `case_sensitive` is set.
    globs: Vec<String>,
    /// Compiled regexes paired with the full original pattern text
    /// (including the `regex:` prefix), kept in insertion order.
    regexes: Vec<(String, Regex)>,
    /// When `false`, literal and glob patterns and the looked-up value are
    /// lowercased before comparison. Regexes are always run on the raw value.
    case_sensitive: bool,
    /// Number of lookups that matched some pattern.
    seen: AtomicU64,
}
69
70impl AllowlistMatcher {
71 #[must_use]
82 pub fn new(patterns: Vec<String>) -> (Self, Vec<String>) {
83 Self::build(patterns, false)
84 }
85
86 #[must_use]
91 pub fn new_case_sensitive(patterns: Vec<String>) -> (Self, Vec<String>) {
92 Self::build(patterns, true)
93 }
94
95 fn build(patterns: Vec<String>, case_sensitive: bool) -> (Self, Vec<String>) {
96 let mut exact = HashSet::new();
97 let mut globs = Vec::new();
98 let mut regexes = Vec::new();
99 let mut warnings = Vec::new();
100
101 for pat in patterns {
102 if let Some(re_src) = pat.strip_prefix("regex:") {
103 match Regex::new(re_src) {
104 Ok(compiled) => regexes.push((pat, compiled)),
105 Err(e) => warnings.push(format!(
106 "allowlist pattern '{pat}' failed to compile: {e} — pattern skipped"
107 )),
108 }
109 continue;
110 }
111
112 for ch in ['^', '$', '+', '(', ')'] {
113 if pat.contains(ch) {
114 warnings.push(format!(
115 "allowlist pattern '{pat}' contains regex character '{ch}'; \
116 it is matched literally — use the 'regex:' prefix for regex syntax"
117 ));
118 break;
119 }
120 }
121 let stored = if case_sensitive {
124 pat
125 } else {
126 pat.to_lowercase()
127 };
128 if stored.contains('*') {
129 globs.push(stored);
130 } else {
131 exact.insert(stored);
132 }
133 }
134
135 (
136 Self {
137 exact,
138 globs,
139 regexes,
140 case_sensitive,
141 seen: AtomicU64::new(0),
142 },
143 warnings,
144 )
145 }
146
147 pub fn is_allowed(&self, value: &str) -> bool {
151 self.match_pattern(value).is_some()
152 }
153
154 pub fn match_pattern<'a>(&'a self, value: &str) -> Option<&'a str> {
166 let normalized: std::borrow::Cow<str> = if self.case_sensitive {
168 std::borrow::Cow::Borrowed(value)
169 } else {
170 std::borrow::Cow::Owned(value.to_lowercase())
171 };
172 if let Some(s) = self.exact.get(normalized.as_ref()) {
173 self.seen.fetch_add(1, Ordering::Relaxed);
174 return Some(s.as_str());
175 }
176 for pat in &self.globs {
177 if glob_matches(pat, &normalized) {
178 self.seen.fetch_add(1, Ordering::Relaxed);
179 return Some(pat.as_str());
180 }
181 }
182 for (pat_str, re) in &self.regexes {
185 if re.is_match(value) {
186 self.seen.fetch_add(1, Ordering::Relaxed);
187 return Some(pat_str.as_str());
188 }
189 }
190 None
191 }
192
193 pub fn seen_count(&self) -> u64 {
195 self.seen.load(Ordering::Relaxed)
196 }
197
198 pub fn pattern_count(&self) -> usize {
200 self.exact.len() + self.globs.len() + self.regexes.len()
201 }
202
203 pub fn is_empty(&self) -> bool {
205 self.exact.is_empty() && self.globs.is_empty() && self.regexes.is_empty()
206 }
207}
208
/// Returns `true` when `value` matches the glob `pattern`.
///
/// `*` matches any run of characters, including the empty run; every other
/// character must match literally. Consecutive `*` behave like a single one.
/// A pattern without any `*` matches only a value equal to it.
///
/// The comparison is byte-exact; callers that want case-insensitive matching
/// must normalise both arguments beforehand.
pub(crate) fn glob_matches(pattern: &str, value: &str) -> bool {
    // Text before the first `*` must open the value.
    let (prefix, rest) = match pattern.split_once('*') {
        Some(split) => split,
        // No wildcard at all: plain literal comparison.
        None => return pattern == value,
    };
    // Text after the last `*` must close the value; anything between the
    // first and last `*` is the middle section.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some(split) => split,
        None => ("", rest),
    };

    // Prefix and suffix may not share bytes of `value`. Checking this for
    // every wildcard count also guarantees the slice below is well-formed.
    if value.len() < prefix.len() + suffix.len()
        || !value.starts_with(prefix)
        || !value.ends_with(suffix)
    {
        return false;
    }

    // Both cut points are char boundaries because `value` starts with
    // `prefix` and ends with `suffix`.
    let mut remaining = &value[prefix.len()..value.len() - suffix.len()];

    // Leftmost match for each middle literal leaves the most room for the
    // literals that follow, so a greedy left-to-right scan is sufficient.
    for part in middle.split('*').filter(|part| !part.is_empty()) {
        match remaining.find(part) {
            Some(at) => remaining = &remaining[at + part.len()..],
            None => return false,
        }
    }
    true
}
246
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a case-insensitive matcher, discarding construction warnings.
    fn matcher(pats: &[&str]) -> AllowlistMatcher {
        let (m, _) = AllowlistMatcher::new(pats.iter().map(|s| (*s).to_string()).collect());
        m
    }

    // Builds a case-sensitive matcher, discarding construction warnings.
    fn matcher_cs(pats: &[&str]) -> AllowlistMatcher {
        let (m, _) =
            AllowlistMatcher::new_case_sensitive(pats.iter().map(|s| (*s).to_string()).collect());
        m
    }

    // Literal patterns match regardless of case by default, but never as a prefix.
    #[test]
    fn exact_match() {
        let m = matcher(&["localhost", "127.0.0.1"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("127.0.0.1"));
        assert!(m.is_allowed("Localhost"));
        assert!(m.is_allowed("LOCALHOST"));
        assert!(!m.is_allowed("localhost2"));
    }

    // In case-sensitive mode only the exact spelling matches.
    #[test]
    fn exact_match_case_sensitive() {
        let m = matcher_cs(&["localhost", "127.0.0.1"]);
        assert!(m.is_allowed("localhost"));
        assert!(!m.is_allowed("Localhost"));
        assert!(!m.is_allowed("LOCALHOST"));
    }

    // A leading `*` anchors the literal at the end of the value.
    #[test]
    fn glob_suffix() {
        let m = matcher(&["*.internal"]);
        assert!(m.is_allowed("db.internal"));
        assert!(m.is_allowed("staging.db.internal"));
        assert!(!m.is_allowed("db.internal.evil"));
        // The dot is part of the literal suffix, so the bare word does not match.
        assert!(!m.is_allowed("internal"));
    }

    // A trailing `*` anchors the literal at the start of the value.
    #[test]
    fn glob_prefix() {
        let m = matcher(&["192.168.1.*"]);
        assert!(m.is_allowed("192.168.1.1"));
        assert!(m.is_allowed("192.168.1.255"));
        assert!(!m.is_allowed("192.168.2.1"));
        // `*` may match the empty string.
        assert!(m.is_allowed("192.168.1."));
    }

    // A `*` in the middle requires both the prefix and the suffix to match.
    #[test]
    fn glob_middle() {
        let m = matcher(&["user-*@corp.com"]);
        assert!(m.is_allowed("user-alice@corp.com"));
        assert!(m.is_allowed("user-bob@corp.com"));
        assert!(!m.is_allowed("admin@corp.com"));
        assert!(!m.is_allowed("user-alice@other.com"));
    }

    // A lone `*` matches everything, including the empty string.
    #[test]
    fn glob_star_only() {
        let m = matcher(&["*"]);
        assert!(m.is_allowed("anything"));
        assert!(m.is_allowed(""));
    }

    // Only successful lookups are counted.
    #[test]
    fn seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        m.is_allowed("ok");
        m.is_allowed("ok");
        m.is_allowed("not-ok");
        assert_eq!(m.seen_count(), 2);
    }

    // A non-prefixed pattern containing regex metacharacters yields a warning.
    #[test]
    fn regex_char_warning() {
        let (_, warnings) = AllowlistMatcher::new(vec!["^bad$".into()]);
        assert!(!warnings.is_empty());
    }

    // With no patterns nothing is allowed.
    #[test]
    fn empty_allowlist_is_empty() {
        let m = matcher(&[]);
        assert!(m.is_empty());
        assert!(!m.is_allowed("anything"));
    }

    // `match_pattern` reports the literal that matched.
    #[test]
    fn match_pattern_returns_exact_pattern() {
        let m = matcher(&["localhost"]);
        assert_eq!(m.match_pattern("localhost"), Some("localhost"));
        assert_eq!(m.match_pattern("other"), None);
    }

    // `match_pattern` reports the glob that matched.
    #[test]
    fn match_pattern_returns_glob_pattern() {
        let m = matcher(&["*.internal"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
        assert_eq!(m.match_pattern("github.com"), None);
    }

    // Both globs match; the one listed first wins.
    #[test]
    fn match_pattern_returns_first_matching_pattern() {
        let m = matcher(&["*.internal", "db.*"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
    }

    // A miss leaves the counter untouched.
    #[test]
    fn match_pattern_increments_seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        m.match_pattern("ok");
        assert_eq!(m.seen_count(), 1);
        m.match_pattern("not-ok");
        assert_eq!(m.seen_count(), 1);
    }

    // `is_allowed` shares the counter with `match_pattern`: one hit, one miss.
    #[test]
    fn is_allowed_delegates_to_match_pattern() {
        let m = matcher(&["*.internal"]);
        assert!(m.is_allowed("db.internal"));
        assert!(!m.is_allowed("github.com"));
        assert_eq!(m.seen_count(), 1);
    }

    // Middle literals must appear in order between the prefix and the suffix.
    #[test]
    fn glob_multiple_wildcards() {
        let m = matcher(&["a*b*c"]);
        assert!(m.is_allowed("abc"));
        assert!(m.is_allowed("aXbYc"));
        assert!(m.is_allowed("aXXXbYYYc"));
        assert!(!m.is_allowed("abX"));
        assert!(!m.is_allowed("Xbc"));
    }

    // `**` behaves like a single `*`.
    #[test]
    fn glob_adjacent_wildcards_treated_as_one() {
        let m = matcher(&["a**b"]);
        assert!(m.is_allowed("ab"));
        assert!(m.is_allowed("aXb"));
        assert!(!m.is_allowed("ba"));
    }

    // The empty value matches `*` but not a glob with a non-empty literal.
    #[test]
    fn glob_empty_value_only_matches_star() {
        let m = matcher(&["*"]);
        assert!(m.is_allowed(""));
        let m2 = matcher(&["a*"]);
        assert!(!m2.is_allowed(""));
    }

    // Prefix and suffix must not share characters of the value.
    #[test]
    fn glob_prefix_suffix_overlap_rejected() {
        let m = matcher(&["a*b"]);
        assert!(!m.is_allowed("a"));
        assert!(!m.is_allowed("b"));
        assert!(m.is_allowed("ab"));
        assert!(m.is_allowed("aXb"));
    }

    // Every entry of a large literal list is matched; near-misses are not.
    #[test]
    fn large_exact_list_all_match() {
        let words: Vec<String> = (0..500).map(|i| format!("word{i}")).collect();
        let (m, _) = AllowlistMatcher::new(words.clone());
        for w in &words {
            assert!(m.is_allowed(w), "should allow {w}");
        }
        assert!(!m.is_allowed("word500"));
        assert!(!m.is_allowed("notaword"));
    }

    // Literals and globs can be mixed in one allowlist.
    #[test]
    fn exact_and_glob_coexist() {
        let m = matcher(&["localhost", "127.0.0.1", "*.internal"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("127.0.0.1"));
        assert!(m.is_allowed("db.internal"));
        assert!(!m.is_allowed("github.com"));
    }

    // A `regex:` pattern is compiled and applied as a regular expression.
    #[test]
    fn regex_basic_match() {
        let m = matcher(&["regex:^192\\.168\\.[0-9]+\\.[0-9]+$"]);
        assert!(m.is_allowed("192.168.1.1"));
        assert!(m.is_allowed("192.168.100.255"));
        assert!(!m.is_allowed("192.168.1."));
        assert!(!m.is_allowed("10.0.0.1"));
    }

    // Without anchors a regex matches anywhere inside the value.
    #[test]
    fn regex_substring_match_without_anchors() {
        let m = matcher(&["regex:internal"]);
        assert!(m.is_allowed("db.internal.corp"));
        assert!(m.is_allowed("internal"));
        assert!(!m.is_allowed("external"));
    }

    // `^...$` forces the regex to cover the whole value.
    #[test]
    fn regex_anchored_full_match() {
        let m = matcher(&["regex:^token-[A-Z]{3}-[0-9]{4}$"]);
        assert!(m.is_allowed("token-ABC-1234"));
        assert!(!m.is_allowed("token-AB-1234"));
        assert!(!m.is_allowed("xtoken-ABC-1234"));
    }

    // Regexes see the raw value even when the matcher is case-insensitive.
    #[test]
    fn regex_case_sensitive_by_default() {
        let m = matcher(&["regex:^localhost$"]);
        assert!(m.is_allowed("localhost"));
        assert!(!m.is_allowed("LOCALHOST"));
        assert!(!m.is_allowed("Localhost"));
    }

    // Case-insensitive regex matching is opted into with the `(?i)` flag.
    #[test]
    fn regex_case_insensitive_via_flag() {
        let m = matcher(&["regex:(?i)^localhost$"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("LOCALHOST"));
        assert!(m.is_allowed("LocalHost"));
    }

    // A regex that fails to compile is reported and not stored.
    #[test]
    fn regex_invalid_pattern_produces_warning_and_is_skipped() {
        let (m, warnings) = AllowlistMatcher::new(vec!["regex:[invalid".into()]);
        assert!(!warnings.is_empty(), "invalid regex must produce a warning");
        assert!(warnings[0].contains("failed to compile"));
        assert!(!m.is_allowed("anything"));
        assert_eq!(m.pattern_count(), 0);
    }

    // The reported pattern keeps its `regex:` prefix.
    #[test]
    fn regex_match_pattern_returns_full_prefixed_string() {
        let m = matcher(&["regex:^10\\.0\\."]);
        assert_eq!(m.match_pattern("10.0.1.5"), Some("regex:^10\\.0\\."),);
        assert_eq!(m.match_pattern("192.168.1.1"), None);
    }

    // Regex hits are counted like any other hit.
    #[test]
    fn regex_seen_counter_increments() {
        let m = matcher(&["regex:^test"]);
        assert_eq!(m.seen_count(), 0);
        m.is_allowed("test-value");
        m.is_allowed("test-value");
        m.is_allowed("other");
        assert_eq!(m.seen_count(), 2);
    }

    // All three pattern kinds work side by side and are all counted.
    #[test]
    fn regex_coexists_with_exact_and_glob() {
        let m = matcher(&[
            "localhost",
            "*.internal",
            "regex:^10\\.[0-9]+\\.[0-9]+\\.[0-9]+$",
        ]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("db.internal"));
        assert!(m.is_allowed("10.0.0.1"));
        assert!(m.is_allowed("10.255.255.255"));
        assert!(!m.is_allowed("192.168.1.1"));
        assert!(!m.is_allowed("github.com"));
        assert_eq!(m.pattern_count(), 3);
    }

    // The value is not lowercased before being handed to a regex.
    #[test]
    fn regex_not_subject_to_case_insensitive_lowercasing() {
        let m = matcher(&["regex:^[A-Z]{3}$"]);
        assert!(m.is_allowed("ABC"));
        assert!(!m.is_allowed("abc"));
    }

    // The metacharacter warning points the user at the `regex:` prefix.
    #[test]
    fn metacharacter_warning_updated_to_suggest_regex_prefix() {
        let (_, warnings) = AllowlistMatcher::new(vec!["^bad$".into()]);
        assert!(!warnings.is_empty());
        assert!(
            warnings[0].contains("regex:"),
            "warning should suggest regex: prefix, got: {}",
            warnings[0],
        );
    }
}