// skp_ratelimit/storage/memory_gc.rs

//! In-memory storage with automatic garbage collection.
//!
//! This storage backend uses `DashMap` for thread-safe concurrent access
//! and includes configurable garbage collection to prevent unbounded memory growth.
5
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9
10use dashmap::DashMap;
11use parking_lot::Mutex;
12use tokio::sync::Notify;
13
14use crate::error::Result;
15use crate::storage::{current_timestamp_ms, Storage, StorageEntry};
16
/// Policy that decides when the garbage collector runs.
#[derive(Clone, Debug)]
pub enum GcInterval {
    /// Trigger a sweep once every `N` storage requests.
    Requests(u64),
    /// Trigger a sweep each time the given period elapses.
    Duration(Duration),
    /// Never sweep automatically; GC only runs when invoked by hand.
    Manual,
}
27
28impl Default for GcInterval {
29    fn default() -> Self {
30        Self::Requests(10000)
31    }
32}
33
34/// Garbage collection configuration.
35#[derive(Debug, Clone)]
36pub struct GcConfig {
37    /// When to trigger GC.
38    pub interval: GcInterval,
39    /// Maximum age of entries before cleanup (default: 1 hour).
40    pub max_age: Duration,
41}
42
43impl Default for GcConfig {
44    fn default() -> Self {
45        Self {
46            interval: GcInterval::default(),
47            max_age: Duration::from_secs(3600),
48        }
49    }
50}
51
52impl GcConfig {
53    /// Create config with request-based GC.
54    pub fn on_requests(count: u64) -> Self {
55        Self {
56            interval: GcInterval::Requests(count),
57            ..Default::default()
58        }
59    }
60
61    /// Create config with time-based GC.
62    pub fn on_duration(interval: Duration) -> Self {
63        Self {
64            interval: GcInterval::Duration(interval),
65            ..Default::default()
66        }
67    }
68
69    /// Create config with manual GC only.
70    pub fn manual() -> Self {
71        Self {
72            interval: GcInterval::Manual,
73            ..Default::default()
74        }
75    }
76
77    /// Set the maximum age for entries.
78    pub fn with_max_age(mut self, max_age: Duration) -> Self {
79        self.max_age = max_age;
80        self
81    }
82}
83
84/// Internal entry with expiration tracking.
85#[derive(Debug, Clone)]
86struct InternalEntry {
87    entry: StorageEntry,
88    expires_at: u64,
89}
90
91/// In-memory storage with garbage collection.
92///
93/// Uses `DashMap` for thread-safe concurrent access and includes
94/// configurable garbage collection to prevent unbounded memory growth.
95///
96/// # Example
97///
98/// ```ignore
99/// use skp_ratelimit::storage::{MemoryStorage, GcConfig};
100/// use std::time::Duration;
101///
102/// // Default GC (every 10000 requests)
103/// let storage = MemoryStorage::new();
104///
105/// // Custom GC interval
106/// let storage = MemoryStorage::with_gc(GcConfig::on_duration(Duration::from_secs(60)));
107///
108/// // Manual GC only
109/// let storage = MemoryStorage::with_gc(GcConfig::manual());
110/// storage.run_gc().await;
111/// ```
112pub struct MemoryStorage {
113    data: DashMap<String, InternalEntry>,
114    gc_config: GcConfig,
115    request_count: AtomicU64,
116    #[allow(dead_code)]
117    last_gc: AtomicU64,
118    gc_lock: Mutex<()>,
119    shutdown: Arc<Notify>,
120}
121
122impl std::fmt::Debug for MemoryStorage {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("MemoryStorage")
125            .field("entries", &self.data.len())
126            .field("gc_config", &self.gc_config)
127            .finish()
128    }
129}
130
131impl Default for MemoryStorage {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl MemoryStorage {
138    /// Create a new memory storage with default GC configuration.
139    pub fn new() -> Self {
140        Self::with_gc(GcConfig::default())
141    }
142
143    /// Create a new memory storage with custom GC configuration.
144    pub fn with_gc(gc_config: GcConfig) -> Self {
145        let storage = Self {
146            data: DashMap::new(),
147            gc_config: gc_config.clone(),
148            request_count: AtomicU64::new(0),
149            last_gc: AtomicU64::new(current_timestamp_ms()),
150            gc_lock: Mutex::new(()),
151            shutdown: Arc::new(Notify::new()),
152        };
153
154        // Start background GC task if duration-based
155        if let GcInterval::Duration(interval) = gc_config.interval {
156            storage.start_gc_task(interval);
157        }
158
159        storage
160    }
161
162    /// Start background GC task.
163    fn start_gc_task(&self, interval: Duration) {
164        let data = self.data.clone();
165        let max_age = self.gc_config.max_age;
166        let shutdown = self.shutdown.clone();
167
168        tokio::spawn(async move {
169            loop {
170                tokio::select! {
171                    _ = tokio::time::sleep(interval) => {
172                        run_gc_on_map(&data, max_age);
173                    }
174                    _ = shutdown.notified() => {
175                        break;
176                    }
177                }
178            }
179        });
180    }
181
182    /// Manually trigger garbage collection.
183    pub async fn run_gc(&self) {
184        run_gc_on_map(&self.data, self.gc_config.max_age);
185    }
186
187    /// Get the number of entries currently stored.
188    pub fn len(&self) -> usize {
189        self.data.len()
190    }
191
192    /// Check if the storage is empty.
193    pub fn is_empty(&self) -> bool {
194        self.data.is_empty()
195    }
196
197    /// Clear all entries.
198    pub fn clear(&self) {
199        self.data.clear();
200    }
201
202    /// Check if GC should run and run it if needed.
203    fn maybe_run_gc(&self) {
204        if let GcInterval::Requests(threshold) = self.gc_config.interval {
205            let count = self.request_count.fetch_add(1, Ordering::Relaxed);
206            if count.is_multiple_of(threshold) && count > 0 {
207                // Try to acquire GC lock (non-blocking)
208                if let Some(_guard) = self.gc_lock.try_lock() {
209                    run_gc_on_map(&self.data, self.gc_config.max_age);
210                }
211            }
212        }
213    }
214}
215
216impl Drop for MemoryStorage {
217    fn drop(&mut self) {
218        self.shutdown.notify_waiters();
219    }
220}
221
222/// Run garbage collection on a DashMap.
223fn run_gc_on_map(data: &DashMap<String, InternalEntry>, max_age: Duration) {
224    let now = current_timestamp_ms();
225    let max_age_ms = max_age.as_millis() as u64;
226    let cutoff = now.saturating_sub(max_age_ms);
227
228    data.retain(|_, entry| {
229        // Keep if not expired and not too old
230        entry.expires_at > now || entry.entry.last_update > cutoff
231    });
232}
233
234impl Storage for MemoryStorage {
235    async fn get(&self, key: &str) -> Result<Option<StorageEntry>> {
236        self.maybe_run_gc();
237
238        let now = current_timestamp_ms();
239        if let Some(internal) = self.data.get(key) {
240            if internal.expires_at > now {
241                return Ok(Some(internal.entry.clone()));
242            }
243            // Entry expired, remove it
244            drop(internal);
245            self.data.remove(key);
246        }
247        Ok(None)
248    }
249
250    async fn set(&self, key: &str, entry: StorageEntry, ttl: Duration) -> Result<()> {
251        self.maybe_run_gc();
252
253        let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
254        self.data.insert(
255            key.to_string(),
256            InternalEntry { entry, expires_at },
257        );
258        Ok(())
259    }
260
261    async fn delete(&self, key: &str) -> Result<()> {
262        self.data.remove(key);
263        Ok(())
264    }
265
266    async fn increment(
267        &self,
268        key: &str,
269        delta: u64,
270        window_start: u64,
271        ttl: Duration,
272    ) -> Result<u64> {
273        self.maybe_run_gc();
274
275        let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
276        let now = current_timestamp_ms();
277
278        let new_count = self.data
279            .entry(key.to_string())
280            .and_modify(|internal| {
281                // Check if we're in a new window
282                if internal.entry.window_start != window_start {
283                    // Store old count as prev_count for sliding window
284                    internal.entry.prev_count = Some(internal.entry.count);
285                    internal.entry.count = delta;
286                    internal.entry.window_start = window_start;
287                } else {
288                    internal.entry.count += delta;
289                }
290                internal.entry.last_update = now;
291                internal.expires_at = expires_at;
292            })
293            .or_insert_with(|| InternalEntry {
294                entry: StorageEntry::new(delta, window_start).set_last_update(now),
295                expires_at,
296            })
297            .entry
298            .count;
299
300        Ok(new_count)
301    }
302
303    async fn execute_atomic<F, T>(&self, key: &str, ttl: Duration, operation: F) -> Result<T>
304    where
305        F: FnOnce(Option<StorageEntry>) -> (StorageEntry, T) + Send,
306        T: Send,
307    {
308        self.maybe_run_gc();
309
310        let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
311        let now = current_timestamp_ms();
312
313        // Get current entry
314        let current = self.data.get(key).and_then(|internal| {
315            if internal.expires_at > now {
316                Some(internal.entry.clone())
317            } else {
318                None
319            }
320        });
321
322        // Execute operation
323        let (new_entry, result) = operation(current);
324
325        // Store new entry
326        self.data.insert(
327            key.to_string(),
328            InternalEntry {
329                entry: new_entry,
330                expires_at,
331            },
332        );
333
334        Ok(result)
335    }
336
337    async fn compare_and_swap(
338        &self,
339        key: &str,
340        expected: Option<&StorageEntry>,
341        new: StorageEntry,
342        ttl: Duration,
343    ) -> Result<bool> {
344        self.maybe_run_gc();
345
346        let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
347        let now = current_timestamp_ms();
348
349        // Get current entry
350        let current = self.data.get(key).and_then(|internal| {
351            if internal.expires_at > now {
352                Some(internal.entry.clone())
353            } else {
354                None
355            }
356        });
357
358        // Check if expected matches current
359        let matches = match (expected, &current) {
360            (None, None) => true,
361            (Some(exp), Some(cur)) => exp == cur,
362            _ => false,
363        };
364
365        if matches {
366            self.data.insert(
367                key.to_string(),
368                InternalEntry {
369                    entry: new,
370                    expires_at,
371                },
372            );
373            Ok(true)
374        } else {
375            Ok(false)
376        }
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL used by tests whose entries must stay live for the whole test.
    const LONG_TTL: Duration = Duration::from_secs(60);

    /// Read-modify-write step shared by the `execute_atomic` test:
    /// bump the stored count by one and report the new value.
    fn bump(current: Option<StorageEntry>) -> (StorageEntry, u64) {
        let next = current.map_or(0, |e| e.count) + 1;
        (StorageEntry::new(next, 1000), next)
    }

    /// A value written with `set` is returned unchanged by `get`.
    #[tokio::test]
    async fn test_memory_storage_basic() {
        let store = MemoryStorage::new();
        let stored = StorageEntry::new(5, 1000);

        store.set("key1", stored.clone(), LONG_TTL).await.unwrap();

        let fetched = store.get("key1").await.unwrap();
        assert_eq!(fetched, Some(stored));
    }

    /// An entry is no longer visible once its TTL has elapsed.
    #[tokio::test]
    async fn test_memory_storage_expiration() {
        let store = MemoryStorage::new();

        store
            .set("key1", StorageEntry::new(5, 1000), Duration::from_millis(10))
            .await
            .unwrap();

        // Let the 10 ms TTL run out.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let fetched = store.get("key1").await.unwrap();
        assert!(fetched.is_none());
    }

    /// `increment` accumulates within a window and rolls over on a new one.
    #[tokio::test]
    async fn test_memory_storage_increment() {
        let store = MemoryStorage::new();

        let first = store.increment("key1", 1, 1000, LONG_TTL).await.unwrap();
        assert_eq!(first, 1);

        let second = store.increment("key1", 1, 1000, LONG_TTL).await.unwrap();
        assert_eq!(second, 2);

        // A different window start resets the running count.
        let rolled = store.increment("key1", 1, 2000, LONG_TTL).await.unwrap();
        assert_eq!(rolled, 1);

        // The previous window's total is preserved in `prev_count`.
        let snapshot = store.get("key1").await.unwrap().unwrap();
        assert_eq!(snapshot.prev_count, Some(2));
    }

    /// `execute_atomic` passes the stored entry to the operation each time.
    #[tokio::test]
    async fn test_memory_storage_execute_atomic() {
        let store = MemoryStorage::new();

        let first = store
            .execute_atomic("key1", LONG_TTL, bump)
            .await
            .unwrap();
        assert_eq!(first, 1);

        let second = store
            .execute_atomic("key1", LONG_TTL, bump)
            .await
            .unwrap();
        assert_eq!(second, 2);
    }

    /// `compare_and_swap` succeeds only when the expectation matches.
    #[tokio::test]
    async fn test_memory_storage_cas() {
        let store = MemoryStorage::new();

        // Expecting "absent" on a key that does not exist succeeds.
        let initial = StorageEntry::new(1, 1000);
        let inserted = store
            .compare_and_swap("key1", None, initial.clone(), LONG_TTL)
            .await
            .unwrap();
        assert!(inserted);

        // A mismatching expectation leaves the entry untouched.
        let mismatch = StorageEntry::new(999, 1000);
        let replacement = StorageEntry::new(2, 1000);
        let swapped = store
            .compare_and_swap("key1", Some(&mismatch), replacement.clone(), LONG_TTL)
            .await
            .unwrap();
        assert!(!swapped);

        // The correct expectation performs the swap.
        let swapped = store
            .compare_and_swap("key1", Some(&initial), replacement.clone(), LONG_TTL)
            .await
            .unwrap();
        assert!(swapped);
    }

    /// The builder helpers set the interval and the maximum age.
    #[tokio::test]
    async fn test_gc_config() {
        let one_hour = Duration::from_secs(3600);
        let cfg = GcConfig::on_requests(1000).with_max_age(one_hour);

        assert!(matches!(cfg.interval, GcInterval::Requests(1000)));
        assert_eq!(cfg.max_age, one_hour);
    }
}