rust_expect/expect/
cache.rs

1//! Regex cache for efficient pattern matching.
2//!
3//! This module provides a cache for compiled regular expressions,
4//! avoiding the overhead of recompiling patterns on each use.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, RwLock};
9
10use regex::Regex;
11
/// Default maximum cache size.
///
/// This is the limit that [`RegexCache::with_default_size`] passes to
/// [`RegexCache::new`].
pub const DEFAULT_CACHE_SIZE: usize = 100;
14
15/// A cache for compiled regular expressions.
16///
17/// The cache uses LRU (Least Recently Used) eviction when full.
18pub struct RegexCache {
19    cache: RwLock<LruCache>,
20    max_size: usize,
21    /// Total cache hits (for statistics).
22    total_hits: AtomicUsize,
23    /// Total cache misses (for statistics).
24    total_misses: AtomicUsize,
25}
26
27struct LruCache {
28    entries: HashMap<String, CacheEntry>,
29    order: Vec<String>,
30}
31
32struct CacheEntry {
33    regex: Arc<Regex>,
34    /// Number of times this pattern has been accessed.
35    hits: AtomicUsize,
36}
37
38impl RegexCache {
39    /// Create a new regex cache with the specified maximum size.
40    #[must_use]
41    pub fn new(max_size: usize) -> Self {
42        Self {
43            cache: RwLock::new(LruCache {
44                entries: HashMap::with_capacity(max_size),
45                order: Vec::with_capacity(max_size),
46            }),
47            max_size,
48            total_hits: AtomicUsize::new(0),
49            total_misses: AtomicUsize::new(0),
50        }
51    }
52
53    /// Create a new regex cache with default size.
54    #[must_use]
55    pub fn with_default_size() -> Self {
56        Self::new(DEFAULT_CACHE_SIZE)
57    }
58
59    /// Get or compile a regex pattern.
60    ///
61    /// Returns a cached regex if available, otherwise compiles and caches it.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the pattern is invalid.
66    pub fn get_or_compile(&self, pattern: &str) -> Result<Arc<Regex>, regex::Error> {
67        // Try read path first
68        // Note: We recover from lock poisoning since the cache is just an optimization
69        {
70            let cache = self
71                .cache
72                .read()
73                .unwrap_or_else(std::sync::PoisonError::into_inner);
74            if let Some(entry) = cache.entries.get(pattern) {
75                // Track cache hit
76                entry.hits.fetch_add(1, Ordering::Relaxed);
77                self.total_hits.fetch_add(1, Ordering::Relaxed);
78                return Ok(Arc::clone(&entry.regex));
79            }
80        }
81
82        // Track cache miss
83        self.total_misses.fetch_add(1, Ordering::Relaxed);
84
85        // Compile the regex
86        let regex = Regex::new(pattern)?;
87        let regex = Arc::new(regex);
88
89        // Update cache
90        {
91            let mut cache = self
92                .cache
93                .write()
94                .unwrap_or_else(std::sync::PoisonError::into_inner);
95
96            // Double-check after acquiring write lock (another thread may have inserted)
97            if let Some(entry) = cache.entries.get(pattern) {
98                // Count as hit since we're returning a cached entry
99                entry.hits.fetch_add(1, Ordering::Relaxed);
100                return Ok(Arc::clone(&entry.regex));
101            }
102
103            // Evict if necessary
104            if cache.entries.len() >= self.max_size
105                && let Some(oldest) = cache.order.first().cloned()
106            {
107                cache.entries.remove(&oldest);
108                cache.order.remove(0);
109            }
110
111            // Insert new entry
112            cache.entries.insert(
113                pattern.to_string(),
114                CacheEntry {
115                    regex: Arc::clone(&regex),
116                    hits: AtomicUsize::new(1), // First access
117                },
118            );
119            cache.order.push(pattern.to_string());
120        }
121
122        Ok(regex)
123    }
124
125    /// Check if a pattern is cached.
126    #[must_use]
127    pub fn contains(&self, pattern: &str) -> bool {
128        let cache = self
129            .cache
130            .read()
131            .unwrap_or_else(std::sync::PoisonError::into_inner);
132        cache.entries.contains_key(pattern)
133    }
134
135    /// Get the current number of cached patterns.
136    #[must_use]
137    pub fn len(&self) -> usize {
138        let cache = self
139            .cache
140            .read()
141            .unwrap_or_else(std::sync::PoisonError::into_inner);
142        cache.entries.len()
143    }
144
145    /// Check if the cache is empty.
146    #[must_use]
147    pub fn is_empty(&self) -> bool {
148        self.len() == 0
149    }
150
151    /// Clear the cache.
152    pub fn clear(&self) {
153        let mut cache = self
154            .cache
155            .write()
156            .unwrap_or_else(std::sync::PoisonError::into_inner);
157        cache.entries.clear();
158        cache.order.clear();
159    }
160
161    /// Get the maximum cache size.
162    #[must_use]
163    pub const fn max_size(&self) -> usize {
164        self.max_size
165    }
166
167    /// Get cache statistics.
168    #[must_use]
169    pub fn stats(&self) -> CacheStats {
170        let cache = self
171            .cache
172            .read()
173            .unwrap_or_else(std::sync::PoisonError::into_inner);
174
175        CacheStats {
176            size: cache.entries.len(),
177            max_size: self.max_size,
178            total_hits: self.total_hits.load(Ordering::Relaxed),
179            total_misses: self.total_misses.load(Ordering::Relaxed),
180        }
181    }
182
183    /// Get the total number of cache hits.
184    #[must_use]
185    pub fn total_hits(&self) -> usize {
186        self.total_hits.load(Ordering::Relaxed)
187    }
188
189    /// Get the total number of cache misses.
190    #[must_use]
191    pub fn total_misses(&self) -> usize {
192        self.total_misses.load(Ordering::Relaxed)
193    }
194
195    /// Get the cache hit rate as a ratio (0.0 to 1.0).
196    ///
197    /// Returns 1.0 if no accesses have been made.
198    #[must_use]
199    pub fn hit_rate(&self) -> f64 {
200        let hits = self.total_hits.load(Ordering::Relaxed);
201        let misses = self.total_misses.load(Ordering::Relaxed);
202        let total = hits + misses;
203        if total == 0 {
204            1.0
205        } else {
206            hits as f64 / total as f64
207        }
208    }
209}
210
/// Statistics about a regex cache.
///
/// This is a plain snapshot: the values do not change after it is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Current number of cached patterns.
    pub size: usize,
    /// Maximum cache size.
    pub max_size: usize,
    /// Total cache hits.
    pub total_hits: usize,
    /// Total cache misses.
    pub total_misses: usize,
}

impl CacheStats {
    /// Get the cache hit rate as a ratio (0.0 to 1.0).
    ///
    /// Returns 1.0 if no accesses have been made.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        if self.total_hits == 0 && self.total_misses == 0 {
            return 1.0;
        }
        // Sum in floating point: the fields are public, so their integer sum
        // could overflow `usize` (a panic in debug builds, a wrapped and
        // therefore wrong total in release builds).
        let hits = self.total_hits as f64;
        hits / (hits + self.total_misses as f64)
    }
}
238
239impl Default for RegexCache {
240    fn default() -> Self {
241        Self::with_default_size()
242    }
243}
244
245/// Global regex cache for shared use.
246pub static GLOBAL_CACHE: std::sync::LazyLock<RegexCache> =
247    std::sync::LazyLock::new(RegexCache::with_default_size);
248
249/// Get or compile a regex using the global cache.
250///
251/// # Errors
252///
253/// Returns an error if the pattern is invalid.
254pub fn get_regex(pattern: &str) -> Result<Arc<Regex>, regex::Error> {
255    GLOBAL_CACHE.get_or_compile(pattern)
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated lookups of one pattern share a single compiled instance.
    #[test]
    fn cache_basic() {
        let cache = RegexCache::new(10);

        let r1 = cache.get_or_compile(r"\d+").unwrap();
        let r2 = cache.get_or_compile(r"\d+").unwrap();

        // Should be the same Arc
        assert!(Arc::ptr_eq(&r1, &r2));
    }

    /// Inserting into a full cache drops the oldest untouched pattern.
    #[test]
    fn cache_eviction() {
        let cache = RegexCache::new(2);

        cache.get_or_compile(r"a+").unwrap();
        cache.get_or_compile(r"b+").unwrap();
        assert_eq!(cache.len(), 2);

        // This should evict "a+"
        cache.get_or_compile(r"c+").unwrap();
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(r"a+"));
        assert!(cache.contains(r"b+"));
        assert!(cache.contains(r"c+"));
    }

    /// An invalid pattern is reported, counted as a miss, and never stored.
    #[test]
    fn cache_invalid_pattern() {
        let cache = RegexCache::new(10);
        let result = cache.get_or_compile(r"[invalid");
        assert!(result.is_err());
        assert!(cache.is_empty());
        assert!(!cache.contains(r"[invalid"));
        assert_eq!(cache.total_misses(), 1);
        assert_eq!(cache.total_hits(), 0);
    }

    /// The global helper goes through the shared cache.
    #[test]
    fn global_cache() {
        let r1 = get_regex(r"\w+").unwrap();
        let r2 = get_regex(r"\w+").unwrap();
        assert!(Arc::ptr_eq(&r1, &r2));
    }

    /// Hits and misses are counted per call.
    #[test]
    fn cache_stats_tracking() {
        let cache = RegexCache::new(10);

        // Initial state
        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_misses, 0);

        // First access (miss)
        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!(cache.total_misses(), 1);
        assert_eq!(cache.total_hits(), 0);

        // Second access (hit)
        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!(cache.total_misses(), 1);
        assert_eq!(cache.total_hits(), 1);

        // Third access (hit)
        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!(cache.total_hits(), 2);

        // New pattern (miss)
        cache.get_or_compile(r"\w+").unwrap();
        assert_eq!(cache.total_misses(), 2);

        // Check hit rate (2 hits out of 4 total = 0.5)
        let hit_rate = cache.hit_rate();
        assert!((hit_rate - 0.5).abs() < 0.001);
    }

    #[test]
    fn cache_stats_hit_rate_empty() {
        let cache = RegexCache::new(10);
        // Empty cache should return 1.0 hit rate (no failures yet)
        assert!((cache.hit_rate() - 1.0).abs() < 0.001);
    }

    /// `stats()` returns a snapshot that agrees with the individual getters.
    #[test]
    fn cache_stats_snapshot() {
        let cache = RegexCache::new(10);
        cache.get_or_compile(r"\d+").unwrap();
        cache.get_or_compile(r"\d+").unwrap();

        let stats = cache.stats();
        assert_eq!(
            stats,
            CacheStats {
                size: 1,
                max_size: 10,
                total_hits: 1,
                total_misses: 1,
            }
        );
        assert!((stats.hit_rate() - cache.hit_rate()).abs() < 0.001);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    /// `clear` drops every pattern and the cache stays usable afterwards.
    #[test]
    fn cache_clear() {
        let cache = RegexCache::new(10);
        assert!(cache.is_empty());

        cache.get_or_compile(r"a+").unwrap();
        cache.get_or_compile(r"b+").unwrap();
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert!(!cache.contains(r"a+"));
        assert!(!cache.contains(r"b+"));

        let again = cache.get_or_compile(r"a+").unwrap();
        assert!(again.is_match("aaa"));
        assert_eq!(cache.len(), 1);
    }

    /// Both default constructors use `DEFAULT_CACHE_SIZE`.
    #[test]
    fn cache_default_size() {
        assert_eq!(RegexCache::default().max_size(), DEFAULT_CACHE_SIZE);
        assert_eq!(
            RegexCache::with_default_size().max_size(),
            DEFAULT_CACHE_SIZE
        );
        assert_eq!(RegexCache::new(3).max_size(), 3);
    }
}