rust_expect/expect/
cache.rs1use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, RwLock};
9
10use regex::Regex;
11
/// Capacity used by [`RegexCache::with_default_size`], and therefore by
/// [`RegexCache::default`] and the process-wide [`GLOBAL_CACHE`].
pub const DEFAULT_CACHE_SIZE: usize = 100;
14
15pub struct RegexCache {
19 cache: RwLock<LruCache>,
20 max_size: usize,
21 total_hits: AtomicUsize,
23 total_misses: AtomicUsize,
25}
26
27struct LruCache {
28 entries: HashMap<String, CacheEntry>,
29 order: Vec<String>,
30}
31
32struct CacheEntry {
33 regex: Arc<Regex>,
34 hits: AtomicUsize,
36}
37
38impl RegexCache {
39 #[must_use]
41 pub fn new(max_size: usize) -> Self {
42 Self {
43 cache: RwLock::new(LruCache {
44 entries: HashMap::with_capacity(max_size),
45 order: Vec::with_capacity(max_size),
46 }),
47 max_size,
48 total_hits: AtomicUsize::new(0),
49 total_misses: AtomicUsize::new(0),
50 }
51 }
52
53 #[must_use]
55 pub fn with_default_size() -> Self {
56 Self::new(DEFAULT_CACHE_SIZE)
57 }
58
59 pub fn get_or_compile(&self, pattern: &str) -> Result<Arc<Regex>, regex::Error> {
67 {
70 let cache = self
71 .cache
72 .read()
73 .unwrap_or_else(std::sync::PoisonError::into_inner);
74 if let Some(entry) = cache.entries.get(pattern) {
75 entry.hits.fetch_add(1, Ordering::Relaxed);
77 self.total_hits.fetch_add(1, Ordering::Relaxed);
78 return Ok(Arc::clone(&entry.regex));
79 }
80 }
81
82 self.total_misses.fetch_add(1, Ordering::Relaxed);
84
85 let regex = Regex::new(pattern)?;
87 let regex = Arc::new(regex);
88
89 {
91 let mut cache = self
92 .cache
93 .write()
94 .unwrap_or_else(std::sync::PoisonError::into_inner);
95
96 if let Some(entry) = cache.entries.get(pattern) {
98 entry.hits.fetch_add(1, Ordering::Relaxed);
100 return Ok(Arc::clone(&entry.regex));
101 }
102
103 if cache.entries.len() >= self.max_size
105 && let Some(oldest) = cache.order.first().cloned()
106 {
107 cache.entries.remove(&oldest);
108 cache.order.remove(0);
109 }
110
111 cache.entries.insert(
113 pattern.to_string(),
114 CacheEntry {
115 regex: Arc::clone(®ex),
116 hits: AtomicUsize::new(1), },
118 );
119 cache.order.push(pattern.to_string());
120 }
121
122 Ok(regex)
123 }
124
125 #[must_use]
127 pub fn contains(&self, pattern: &str) -> bool {
128 let cache = self
129 .cache
130 .read()
131 .unwrap_or_else(std::sync::PoisonError::into_inner);
132 cache.entries.contains_key(pattern)
133 }
134
135 #[must_use]
137 pub fn len(&self) -> usize {
138 let cache = self
139 .cache
140 .read()
141 .unwrap_or_else(std::sync::PoisonError::into_inner);
142 cache.entries.len()
143 }
144
145 #[must_use]
147 pub fn is_empty(&self) -> bool {
148 self.len() == 0
149 }
150
151 pub fn clear(&self) {
153 let mut cache = self
154 .cache
155 .write()
156 .unwrap_or_else(std::sync::PoisonError::into_inner);
157 cache.entries.clear();
158 cache.order.clear();
159 }
160
161 #[must_use]
163 pub const fn max_size(&self) -> usize {
164 self.max_size
165 }
166
167 #[must_use]
169 pub fn stats(&self) -> CacheStats {
170 let cache = self
171 .cache
172 .read()
173 .unwrap_or_else(std::sync::PoisonError::into_inner);
174
175 CacheStats {
176 size: cache.entries.len(),
177 max_size: self.max_size,
178 total_hits: self.total_hits.load(Ordering::Relaxed),
179 total_misses: self.total_misses.load(Ordering::Relaxed),
180 }
181 }
182
183 #[must_use]
185 pub fn total_hits(&self) -> usize {
186 self.total_hits.load(Ordering::Relaxed)
187 }
188
189 #[must_use]
191 pub fn total_misses(&self) -> usize {
192 self.total_misses.load(Ordering::Relaxed)
193 }
194
195 #[must_use]
199 pub fn hit_rate(&self) -> f64 {
200 let hits = self.total_hits.load(Ordering::Relaxed);
201 let misses = self.total_misses.load(Ordering::Relaxed);
202 let total = hits + misses;
203 if total == 0 {
204 1.0
205 } else {
206 hits as f64 / total as f64
207 }
208 }
209}
210
/// Point-in-time snapshot of a [`RegexCache`]'s size and lookup counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of patterns cached when the snapshot was taken.
    pub size: usize,
    /// Maximum number of patterns the cache will hold.
    pub max_size: usize,
    /// Number of lookups served from the cache.
    pub total_hits: usize,
    /// Number of lookups that were not served from the cache.
    pub total_misses: usize,
}

impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when no lookups have been recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        if self.total_hits == 0 && self.total_misses == 0 {
            return 1.0;
        }
        // Sum in f64: the fields are public and may be arbitrarily large,
        // so `total_hits + total_misses` in usize could overflow.
        let hits = self.total_hits as f64;
        hits / (hits + self.total_misses as f64)
    }
}
238
239impl Default for RegexCache {
240 fn default() -> Self {
241 Self::with_default_size()
242 }
243}
244
245pub static GLOBAL_CACHE: std::sync::LazyLock<RegexCache> =
247 std::sync::LazyLock::new(RegexCache::with_default_size);
248
249pub fn get_regex(pattern: &str) -> Result<Arc<Regex>, regex::Error> {
255 GLOBAL_CACHE.get_or_compile(pattern)
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_basic() {
        let cache = RegexCache::new(10);

        let first = cache.get_or_compile(r"\d+").unwrap();
        let second = cache.get_or_compile(r"\d+").unwrap();

        // A repeated lookup must hand back the very same compiled instance.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn cache_eviction() {
        let cache = RegexCache::new(2);

        for pattern in [r"a+", r"b+"] {
            cache.get_or_compile(pattern).unwrap();
        }
        assert_eq!(cache.len(), 2);

        // A third pattern pushes out the earliest insertion.
        cache.get_or_compile(r"c+").unwrap();
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(r"a+"));
        for survivor in [r"b+", r"c+"] {
            assert!(cache.contains(survivor));
        }
    }

    #[test]
    fn cache_invalid_pattern() {
        let cache = RegexCache::new(10);
        assert!(cache.get_or_compile(r"[invalid").is_err());
    }

    #[test]
    fn global_cache() {
        let first = get_regex(r"\w+").unwrap();
        let second = get_regex(r"\w+").unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn cache_stats_tracking() {
        let cache = RegexCache::new(10);

        let initial = cache.stats();
        assert_eq!(initial.size, 0);
        assert_eq!(initial.total_hits, 0);
        assert_eq!(initial.total_misses, 0);

        // First lookup compiles (miss); repeats of the same pattern are hits.
        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!((cache.total_misses(), cache.total_hits()), (1, 0));

        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!((cache.total_misses(), cache.total_hits()), (1, 1));

        cache.get_or_compile(r"\d+").unwrap();
        assert_eq!(cache.total_hits(), 2);

        cache.get_or_compile(r"\w+").unwrap();
        assert_eq!(cache.total_misses(), 2);

        // Two hits out of four lookups.
        assert!((cache.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn cache_stats_hit_rate_empty() {
        let untouched = RegexCache::new(10);
        assert!((untouched.hit_rate() - 1.0).abs() < 0.001);
    }
}