1use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9
10use dashmap::DashMap;
11use parking_lot::Mutex;
12use tokio::sync::Notify;
13
14use crate::error::Result;
15use crate::storage::{current_timestamp_ms, Storage, StorageEntry};
16
/// Policy deciding when [`MemoryStorage`] runs garbage collection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GcInterval {
    /// Run a GC pass inline on every `n`-th GC-eligible storage operation.
    Requests(u64),
    /// Run a GC pass periodically from a background Tokio task.
    Duration(Duration),
    /// Never run GC automatically; call `MemoryStorage::run_gc` explicitly.
    Manual,
}

impl Default for GcInterval {
    /// Request-based GC every 10 000 operations.
    fn default() -> Self {
        Self::Requests(10_000)
    }
}
33
34#[derive(Debug, Clone)]
36pub struct GcConfig {
37 pub interval: GcInterval,
39 pub max_age: Duration,
41}
42
43impl Default for GcConfig {
44 fn default() -> Self {
45 Self {
46 interval: GcInterval::default(),
47 max_age: Duration::from_secs(3600),
48 }
49 }
50}
51
52impl GcConfig {
53 pub fn on_requests(count: u64) -> Self {
55 Self {
56 interval: GcInterval::Requests(count),
57 ..Default::default()
58 }
59 }
60
61 pub fn on_duration(interval: Duration) -> Self {
63 Self {
64 interval: GcInterval::Duration(interval),
65 ..Default::default()
66 }
67 }
68
69 pub fn manual() -> Self {
71 Self {
72 interval: GcInterval::Manual,
73 ..Default::default()
74 }
75 }
76
77 pub fn with_max_age(mut self, max_age: Duration) -> Self {
79 self.max_age = max_age;
80 self
81 }
82}
83
84#[derive(Debug, Clone)]
86struct InternalEntry {
87 entry: StorageEntry,
88 expires_at: u64,
89}
90
91pub struct MemoryStorage {
113 data: DashMap<String, InternalEntry>,
114 gc_config: GcConfig,
115 request_count: AtomicU64,
116 #[allow(dead_code)]
117 last_gc: AtomicU64,
118 gc_lock: Mutex<()>,
119 shutdown: Arc<Notify>,
120}
121
122impl std::fmt::Debug for MemoryStorage {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("MemoryStorage")
125 .field("entries", &self.data.len())
126 .field("gc_config", &self.gc_config)
127 .finish()
128 }
129}
130
131impl Default for MemoryStorage {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl MemoryStorage {
138 pub fn new() -> Self {
140 Self::with_gc(GcConfig::default())
141 }
142
143 pub fn with_gc(gc_config: GcConfig) -> Self {
145 let storage = Self {
146 data: DashMap::new(),
147 gc_config: gc_config.clone(),
148 request_count: AtomicU64::new(0),
149 last_gc: AtomicU64::new(current_timestamp_ms()),
150 gc_lock: Mutex::new(()),
151 shutdown: Arc::new(Notify::new()),
152 };
153
154 if let GcInterval::Duration(interval) = gc_config.interval {
156 storage.start_gc_task(interval);
157 }
158
159 storage
160 }
161
162 fn start_gc_task(&self, interval: Duration) {
164 let data = self.data.clone();
165 let max_age = self.gc_config.max_age;
166 let shutdown = self.shutdown.clone();
167
168 tokio::spawn(async move {
169 loop {
170 tokio::select! {
171 _ = tokio::time::sleep(interval) => {
172 run_gc_on_map(&data, max_age);
173 }
174 _ = shutdown.notified() => {
175 break;
176 }
177 }
178 }
179 });
180 }
181
182 pub async fn run_gc(&self) {
184 run_gc_on_map(&self.data, self.gc_config.max_age);
185 }
186
187 pub fn len(&self) -> usize {
189 self.data.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.data.is_empty()
195 }
196
197 pub fn clear(&self) {
199 self.data.clear();
200 }
201
202 fn maybe_run_gc(&self) {
204 if let GcInterval::Requests(threshold) = self.gc_config.interval {
205 let count = self.request_count.fetch_add(1, Ordering::Relaxed);
206 if count.is_multiple_of(threshold) && count > 0 {
207 if let Some(_guard) = self.gc_lock.try_lock() {
209 run_gc_on_map(&self.data, self.gc_config.max_age);
210 }
211 }
212 }
213 }
214}
215
216impl Drop for MemoryStorage {
217 fn drop(&mut self) {
218 self.shutdown.notify_waiters();
219 }
220}
221
222fn run_gc_on_map(data: &DashMap<String, InternalEntry>, max_age: Duration) {
224 let now = current_timestamp_ms();
225 let max_age_ms = max_age.as_millis() as u64;
226 let cutoff = now.saturating_sub(max_age_ms);
227
228 data.retain(|_, entry| {
229 entry.expires_at > now || entry.entry.last_update > cutoff
231 });
232}
233
234impl Storage for MemoryStorage {
235 async fn get(&self, key: &str) -> Result<Option<StorageEntry>> {
236 self.maybe_run_gc();
237
238 let now = current_timestamp_ms();
239 if let Some(internal) = self.data.get(key) {
240 if internal.expires_at > now {
241 return Ok(Some(internal.entry.clone()));
242 }
243 drop(internal);
245 self.data.remove(key);
246 }
247 Ok(None)
248 }
249
250 async fn set(&self, key: &str, entry: StorageEntry, ttl: Duration) -> Result<()> {
251 self.maybe_run_gc();
252
253 let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
254 self.data.insert(
255 key.to_string(),
256 InternalEntry { entry, expires_at },
257 );
258 Ok(())
259 }
260
261 async fn delete(&self, key: &str) -> Result<()> {
262 self.data.remove(key);
263 Ok(())
264 }
265
266 async fn increment(
267 &self,
268 key: &str,
269 delta: u64,
270 window_start: u64,
271 ttl: Duration,
272 ) -> Result<u64> {
273 self.maybe_run_gc();
274
275 let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
276 let now = current_timestamp_ms();
277
278 let new_count = self.data
279 .entry(key.to_string())
280 .and_modify(|internal| {
281 if internal.entry.window_start != window_start {
283 internal.entry.prev_count = Some(internal.entry.count);
285 internal.entry.count = delta;
286 internal.entry.window_start = window_start;
287 } else {
288 internal.entry.count += delta;
289 }
290 internal.entry.last_update = now;
291 internal.expires_at = expires_at;
292 })
293 .or_insert_with(|| InternalEntry {
294 entry: StorageEntry::new(delta, window_start).set_last_update(now),
295 expires_at,
296 })
297 .entry
298 .count;
299
300 Ok(new_count)
301 }
302
303 async fn execute_atomic<F, T>(&self, key: &str, ttl: Duration, operation: F) -> Result<T>
304 where
305 F: FnOnce(Option<StorageEntry>) -> (StorageEntry, T) + Send,
306 T: Send,
307 {
308 self.maybe_run_gc();
309
310 let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
311 let now = current_timestamp_ms();
312
313 let current = self.data.get(key).and_then(|internal| {
315 if internal.expires_at > now {
316 Some(internal.entry.clone())
317 } else {
318 None
319 }
320 });
321
322 let (new_entry, result) = operation(current);
324
325 self.data.insert(
327 key.to_string(),
328 InternalEntry {
329 entry: new_entry,
330 expires_at,
331 },
332 );
333
334 Ok(result)
335 }
336
337 async fn compare_and_swap(
338 &self,
339 key: &str,
340 expected: Option<&StorageEntry>,
341 new: StorageEntry,
342 ttl: Duration,
343 ) -> Result<bool> {
344 self.maybe_run_gc();
345
346 let expires_at = current_timestamp_ms() + ttl.as_millis() as u64;
347 let now = current_timestamp_ms();
348
349 let current = self.data.get(key).and_then(|internal| {
351 if internal.expires_at > now {
352 Some(internal.entry.clone())
353 } else {
354 None
355 }
356 });
357
358 let matches = match (expected, ¤t) {
360 (None, None) => true,
361 (Some(exp), Some(cur)) => exp == cur,
362 _ => false,
363 };
364
365 if matches {
366 self.data.insert(
367 key.to_string(),
368 InternalEntry {
369 entry: new,
370 expires_at,
371 },
372 );
373 Ok(true)
374 } else {
375 Ok(false)
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    const TTL: Duration = Duration::from_secs(60);

    #[tokio::test]
    async fn test_memory_storage_basic() {
        let storage = MemoryStorage::new();

        let entry = StorageEntry::new(5, 1000);
        storage.set("key1", entry.clone(), TTL).await.unwrap();

        let result = storage.get("key1").await.unwrap();
        assert_eq!(result, Some(entry));
    }

    #[tokio::test]
    async fn test_memory_storage_expiration() {
        let storage = MemoryStorage::new();

        let entry = StorageEntry::new(5, 1000);
        storage.set("key1", entry, Duration::from_millis(10)).await.unwrap();

        // Outlive the 10 ms TTL.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let result = storage.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_memory_storage_increment() {
        let storage = MemoryStorage::new();

        let count = storage.increment("key1", 1, 1000, TTL).await.unwrap();
        assert_eq!(count, 1);

        let count = storage.increment("key1", 1, 1000, TTL).await.unwrap();
        assert_eq!(count, 2);

        // A new window restarts the count.
        let count = storage.increment("key1", 1, 2000, TTL).await.unwrap();
        assert_eq!(count, 1);

        // ...and remembers the previous window's count.
        let entry = storage.get("key1").await.unwrap().unwrap();
        assert_eq!(entry.prev_count, Some(2));
    }

    #[tokio::test]
    async fn test_memory_storage_execute_atomic() {
        let storage = MemoryStorage::new();
        let bump = |current: Option<StorageEntry>| {
            let count = current.map(|e| e.count).unwrap_or(0);
            (StorageEntry::new(count + 1, 1000), count + 1)
        };

        let result = storage.execute_atomic("key1", TTL, bump).await.unwrap();
        assert_eq!(result, 1);

        let result = storage.execute_atomic("key1", TTL, bump).await.unwrap();
        assert_eq!(result, 2);
    }

    #[tokio::test]
    async fn test_memory_storage_cas() {
        let storage = MemoryStorage::new();

        // Absent key: expecting None succeeds.
        let entry = StorageEntry::new(1, 1000);
        let success = storage
            .compare_and_swap("key1", None, entry.clone(), TTL)
            .await
            .unwrap();
        assert!(success);

        // Wrong expectation fails.
        let wrong = StorageEntry::new(999, 1000);
        let entry2 = StorageEntry::new(2, 1000);
        let success = storage
            .compare_and_swap("key1", Some(&wrong), entry2.clone(), TTL)
            .await
            .unwrap();
        assert!(!success);

        // Correct expectation succeeds.
        let success = storage
            .compare_and_swap("key1", Some(&entry), entry2.clone(), TTL)
            .await
            .unwrap();
        assert!(success);
    }

    #[tokio::test]
    async fn test_memory_storage_cas_expected_none_fails_when_present() {
        let storage = MemoryStorage::new();
        let entry = StorageEntry::new(1, 1000);
        storage.set("key1", entry.clone(), TTL).await.unwrap();

        let success = storage
            .compare_and_swap("key1", None, StorageEntry::new(2, 1000), TTL)
            .await
            .unwrap();
        assert!(!success);
        assert_eq!(storage.get("key1").await.unwrap(), Some(entry));
    }

    #[tokio::test]
    async fn test_memory_storage_delete_clear_len() {
        let storage = MemoryStorage::new();
        assert!(storage.is_empty());

        storage.set("a", StorageEntry::new(1, 1000), TTL).await.unwrap();
        storage.set("b", StorageEntry::new(2, 1000), TTL).await.unwrap();
        assert_eq!(storage.len(), 2);

        storage.delete("a").await.unwrap();
        assert!(storage.get("a").await.unwrap().is_none());
        assert_eq!(storage.len(), 1);

        storage.clear();
        assert!(storage.is_empty());
    }

    #[tokio::test]
    async fn test_run_gc_keeps_live_entries() {
        let storage = MemoryStorage::with_gc(GcConfig::manual());
        storage.set("key1", StorageEntry::new(1, 1000), TTL).await.unwrap();

        storage.run_gc().await;
        assert_eq!(storage.len(), 1);
    }

    #[tokio::test]
    async fn test_duration_gc_storage_serves_requests() {
        let storage = MemoryStorage::with_gc(GcConfig::on_duration(Duration::from_millis(5)));
        let entry = StorageEntry::new(3, 1000);
        storage.set("key1", entry.clone(), TTL).await.unwrap();
        assert_eq!(storage.get("key1").await.unwrap(), Some(entry));
    }

    #[test]
    fn test_gc_config() {
        let config = GcConfig::on_requests(1000).with_max_age(Duration::from_secs(3600));

        assert!(matches!(config.interval, GcInterval::Requests(1000)));
        assert_eq!(config.max_age, Duration::from_secs(3600));
    }

    #[test]
    fn test_gc_config_constructors() {
        assert!(matches!(GcConfig::manual().interval, GcInterval::Manual));

        let periodic = GcConfig::on_duration(Duration::from_secs(5));
        assert!(matches!(
            periodic.interval,
            GcInterval::Duration(d) if d == Duration::from_secs(5)
        ));

        let defaults = GcConfig::default();
        assert!(matches!(defaults.interval, GcInterval::Requests(10000)));
        assert_eq!(defaults.max_age, Duration::from_secs(3600));
    }
}