1use crate::cache::CacheConfig;
11use crate::cache::policy::{
12 CacheAdmission, CachePolicy, CachePolicyConfig, CachePolicyKind, build_cache_policy,
13};
14use log::debug;
15use lru::LruCache;
16use regex::Regex;
17use std::hash::{Hash, Hasher};
18use std::num::NonZeroUsize;
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::{Arc, Mutex, OnceLock};
21
22#[derive(Clone)]
24pub enum CompiledRegex {
25 Standard(Arc<Regex>),
27 Fancy(Arc<fancy_regex::Regex>),
29}
30
31impl CompiledRegex {
32 #[must_use]
34 pub fn is_match(&self, text: &str) -> bool {
35 match self {
36 CompiledRegex::Standard(re) => re.is_match(text),
37 CompiledRegex::Fancy(re) => re.is_match(text).unwrap_or(false),
38 }
39 }
40}
41
/// Reports whether `pattern` contains a lookaround group opener:
/// `(?=`, `(?!`, `(?<=` or `(?<!`.
///
/// The `regex` crate does not support lookaround, so such patterns need another engine.
/// This is a plain substring scan: the same character sequences inside an escaped
/// parenthesis or a character class also count.
fn has_lookaround(pattern: &str) -> bool {
    const LOOKAROUND_OPENERS: [&str; 4] = ["(?=", "(?!", "(?<=", "(?<!"];

    LOOKAROUND_OPENERS
        .iter()
        .any(|&opener| pattern.contains(opener))
}
49
50#[derive(Debug)]
52pub enum RegexCompileError {
53 Standard(regex::Error),
55 Fancy(Box<fancy_regex::Error>),
58}
59
60impl std::fmt::Display for RegexCompileError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 RegexCompileError::Standard(e) => write!(f, "{e}"),
64 RegexCompileError::Fancy(e) => write!(f, "{e}"),
65 }
66 }
67}
68
69impl std::error::Error for RegexCompileError {
70 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
71 match self {
72 RegexCompileError::Standard(e) => Some(e),
73 RegexCompileError::Fancy(e) => Some(e.as_ref()),
74 }
75 }
76}
77
78impl From<regex::Error> for RegexCompileError {
79 fn from(err: regex::Error) -> Self {
80 RegexCompileError::Standard(err)
81 }
82}
83
84impl From<fancy_regex::Error> for RegexCompileError {
85 fn from(err: fancy_regex::Error) -> Self {
86 RegexCompileError::Fancy(Box::new(err))
87 }
88}
89
/// Cache identity of a compiled regex: the pattern text plus every flag that affects
/// compilation, so the same pattern under different flags gets its own entry.
#[derive(PartialEq, Eq, Clone)]
struct RegexCacheKey {
    /// Pattern text exactly as supplied by the caller.
    pattern: String,
    /// Case-insensitive matching.
    case_insensitive: bool,
    /// Multi-line mode (`^`/`$` match at line boundaries).
    multiline: bool,
    /// Whether `.` also matches a newline.
    dot_all: bool,
}
98
99impl Hash for RegexCacheKey {
100 fn hash<H: Hasher>(&self, state: &mut H) {
101 self.pattern.hash(state);
102 self.case_insensitive.hash(state);
103 self.multiline.hash(state);
104 self.dot_all.hash(state);
105 }
106}
107
108pub struct RegexCache {
110 cache: Arc<Mutex<LruCache<RegexCacheKey, CompiledRegex>>>,
111 capacity: usize,
112 hits: AtomicU64,
113 misses: AtomicU64,
114 evictions: AtomicU64,
115 policy: Arc<dyn CachePolicy<RegexCacheKey>>,
116}
117
118impl RegexCache {
119 #[must_use]
121 pub fn new(capacity: usize) -> Self {
122 let (kind, window_ratio) = Self::policy_params_from_env();
123 Self::with_policy(capacity, kind, window_ratio)
124 }
125
126 pub fn get_or_compile(
139 &self,
140 pattern: &str,
141 case_insensitive: bool,
142 multiline: bool,
143 dot_all: bool,
144 ) -> Result<CompiledRegex, RegexCompileError> {
145 let key = RegexCacheKey {
146 pattern: pattern.to_string(),
147 case_insensitive,
148 multiline,
149 dot_all,
150 };
151
152 self.handle_policy_evictions();
153
154 {
155 let mut cache = self.cache.lock().expect("regex cache mutex poisoned");
156 if let Some(regex) = cache.get(&key) {
157 self.hits.fetch_add(1, Ordering::Relaxed);
158 let _ = self.policy.record_hit(&key);
159 return Ok(regex.clone());
160 }
161 }
162
163 self.misses.fetch_add(1, Ordering::Relaxed);
164
165 let compiled = if has_lookaround(pattern) {
168 let mut flag_prefix = String::new();
171 if case_insensitive {
172 flag_prefix.push_str("(?i)");
173 }
174 if multiline {
175 flag_prefix.push_str("(?m)");
176 }
177 if dot_all {
178 flag_prefix.push_str("(?s)");
179 }
180 let full_pattern = format!("{flag_prefix}{pattern}");
181 let fancy_re = fancy_regex::Regex::new(&full_pattern)?;
182 CompiledRegex::Fancy(Arc::new(fancy_re))
183 } else {
184 let mut builder = regex::RegexBuilder::new(pattern);
186 builder
187 .case_insensitive(case_insensitive)
188 .multi_line(multiline)
189 .dot_matches_new_line(dot_all);
190 let re = builder.build()?;
191 CompiledRegex::Standard(Arc::new(re))
192 };
193
194 if matches!(self.policy.admit(&key, 1), CacheAdmission::Rejected) {
195 debug!(
196 "regex cache policy {:?} rejected pattern {:?}",
197 self.policy.kind(),
198 key.pattern
199 );
200 return Ok(compiled);
201 }
202
203 {
205 let mut cache = self.cache.lock().expect("regex cache mutex poisoned");
206 if cache.len() == self.capacity
207 && let Some((evicted_key, _)) = cache.pop_lru()
208 {
209 self.policy.invalidate(&evicted_key);
210 self.evictions.fetch_add(1, Ordering::Relaxed);
211 }
212 cache.put(key, compiled.clone());
213 }
214
215 self.handle_policy_evictions();
216
217 Ok(compiled)
218 }
219
220 #[cfg(test)]
226 pub fn len(&self) -> usize {
227 self.cache.lock().expect("regex cache mutex poisoned").len()
228 }
229
230 #[cfg(test)]
232 pub fn is_empty(&self) -> bool {
233 self.len() == 0
234 }
235
236 fn handle_policy_evictions(&self) {
237 let evicted = self.policy.drain_evictions();
238 if evicted.is_empty() {
239 return;
240 }
241 let mut cache = self.cache.lock().expect("regex cache mutex poisoned");
242 for eviction in evicted {
243 if cache.pop(&eviction.key).is_some() {
244 self.evictions.fetch_add(1, Ordering::Relaxed);
245 }
246 }
247 }
248
249 fn with_policy(capacity: usize, kind: CachePolicyKind, window_ratio: f32) -> Self {
250 let normalized_capacity = capacity.max(1);
251 let config = CachePolicyConfig::new(kind, normalized_capacity as u64, window_ratio);
252 Self {
253 cache: Arc::new(Mutex::new(LruCache::new(
254 NonZeroUsize::new(normalized_capacity).expect("capacity must be > 0"),
255 ))),
256 capacity: normalized_capacity,
257 hits: AtomicU64::new(0),
258 misses: AtomicU64::new(0),
259 evictions: AtomicU64::new(0),
260 policy: build_cache_policy(&config),
261 }
262 }
263
264 fn policy_params_from_env() -> (CachePolicyKind, f32) {
265 let cfg = CacheConfig::from_env();
266 (cfg.policy_kind(), cfg.policy_window_ratio())
267 }
268
269 #[cfg(test)]
270 fn with_policy_kind(capacity: usize, kind: CachePolicyKind) -> Self {
271 Self::with_policy(capacity, kind, CacheConfig::DEFAULT_POLICY_WINDOW_RATIO)
272 }
273
274 #[cfg(test)]
275 fn policy_metrics(&self) -> crate::cache::policy::CachePolicyMetrics {
276 self.policy.stats()
277 }
278}
279
280static REGEX_CACHE: OnceLock<RegexCache> = OnceLock::new();
282
283fn get_global_cache() -> &'static RegexCache {
284 REGEX_CACHE.get_or_init(|| {
285 let size = std::env::var("SQRY_REGEX_CACHE_SIZE")
286 .ok()
287 .and_then(|s| s.parse::<usize>().ok())
288 .filter(|&s| (1..=10_000).contains(&s))
289 .unwrap_or(100);
290
291 RegexCache::new(size)
292 })
293}
294
295pub fn get_or_compile_regex(
304 pattern: &str,
305 case_insensitive: bool,
306 multiline: bool,
307 dot_all: bool,
308) -> Result<CompiledRegex, RegexCompileError> {
309 get_global_cache().get_or_compile(pattern, case_insensitive, multiline, dot_all)
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::policy::CachePolicyKind;

    #[test]
    fn test_cache_hit_reuses_compiled_regex() {
        let cache = RegexCache::new(10);

        let re1 = cache.get_or_compile("foo.*", false, false, false).unwrap();
        assert_eq!(cache.len(), 1);

        let re2 = cache.get_or_compile("foo.*", false, false, false).unwrap();
        assert_eq!(cache.len(), 1);
        assert!(re1.is_match("foobar"));

        // A hit must hand back the very same compiled regex, not a fresh compilation.
        match (&re1, &re2) {
            (CompiledRegex::Standard(first), CompiledRegex::Standard(second)) => {
                assert!(
                    Arc::ptr_eq(first, second),
                    "cache hit should reuse the compiled regex"
                );
            }
            _ => panic!("expected `foo.*` to compile with the standard engine"),
        }
    }

    #[test]
    fn test_different_flags_create_separate_entries() {
        let cache = RegexCache::new(10);

        let re1 = cache.get_or_compile("foo", false, false, false).unwrap();
        let re2 = cache.get_or_compile("foo", true, false, false).unwrap();
        assert_eq!(cache.len(), 2);
        assert!(re1.is_match("foo"));
        assert!(!re1.is_match("FOO"));
        assert!(re2.is_match("FOO"));
    }

    #[test]
    fn test_lru_eviction_works() {
        let cache = RegexCache::new(2);

        cache.get_or_compile("a", false, false, false).unwrap();
        cache.get_or_compile("b", false, false, false).unwrap();
        assert_eq!(cache.len(), 2);

        cache.get_or_compile("c", false, false, false).unwrap();
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_compilation_errors_not_cached() {
        let cache = RegexCache::new(10);

        assert!(
            cache
                .get_or_compile("[invalid", false, false, false)
                .is_err()
        );
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn tiny_lfu_rejects_cold_bursts() {
        let cache = RegexCache::with_policy_kind(3, CachePolicyKind::TinyLfu);

        let hot = cache
            .get_or_compile("hot", false, false, false)
            .expect("compile hot regex");
        for _ in 0..10 {
            let _ = cache
                .get_or_compile("hot", false, false, false)
                .expect("warm hot regex");
        }

        for i in 0..30 {
            let pattern = format!("cold{i}");
            let _ = cache
                .get_or_compile(&pattern, false, false, false)
                .expect("compile cold regex");
        }

        let warmed = cache
            .get_or_compile("hot", false, false, false)
            .expect("retrieve hot regex");
        assert!(hot.is_match("hot"));
        assert!(warmed.is_match("hot"));

        let metrics = cache.policy_metrics();
        assert!(
            metrics.lfu_rejects > 0,
            "expected TinyLFU to reject some cold entries"
        );
    }

    #[test]
    fn test_lookahead_pattern_compiles() {
        let cache = RegexCache::new(10);
        let re = cache
            .get_or_compile("foo(?=bar)", false, false, false)
            .expect("lookahead should compile");
        assert!(re.is_match("foobar"));
        assert!(!re.is_match("foobaz"));
    }

    #[test]
    fn test_lookbehind_pattern_compiles() {
        let cache = RegexCache::new(10);
        let re = cache
            .get_or_compile("(?<=test_)foo", false, false, false)
            .expect("lookbehind should compile");
        assert!(re.is_match("test_foo"));
        assert!(!re.is_match("prod_foo"));
    }

    #[test]
    fn test_negative_lookahead_pattern() {
        let cache = RegexCache::new(10);
        let re = cache
            .get_or_compile("foo(?!bar)", false, false, false)
            .expect("negative lookahead should compile");
        assert!(re.is_match("foobaz"));
        assert!(!re.is_match("foobar"));
    }

    #[test]
    fn test_negative_lookbehind_pattern() {
        let cache = RegexCache::new(10);
        let re = cache
            .get_or_compile("(?<!test_)foo", false, false, false)
            .expect("negative lookbehind should compile");
        assert!(re.is_match("prod_foo"));
        assert!(!re.is_match("test_foo"));
    }

    #[test]
    fn test_lookaround_with_flags() {
        let cache = RegexCache::new(10);
        let re = cache
            .get_or_compile("(?<=TEST_)foo", true, false, false)
            .expect("lookaround with flags should compile");
        assert!(re.is_match("TEST_foo"));
        assert!(re.is_match("test_foo"));
        assert!(re.is_match("TEST_FOO"));
    }
}