Skip to main content

sockudo_cache/
memory_cache_manager.rs

1use async_trait::async_trait;
2use moka::future::Cache;
3use sockudo_core::cache::{CacheManager, CacheScanPage};
4use sockudo_core::error::Result;
5use sockudo_core::options::MemoryCacheOptions;
6use std::time::Duration;
7
8/// A Memory-based implementation of the CacheManager trait using Moka.
9#[derive(Clone)]
10pub struct MemoryCacheManager {
11    /// Moka async cache for storing entries. Key and Value are Strings.
12    cache: Cache<String, String, ahash::RandomState>,
13    /// Configuration options for this cache instance.
14    options: MemoryCacheOptions,
15    /// Prefix for all keys in this cache instance.
16    prefix: String,
17}
18
19impl MemoryCacheManager {
20    /// Creates a new Memory cache manager with Moka configuration.
21    pub fn new(prefix: String, options: MemoryCacheOptions) -> Self {
22        let cache_builder = Cache::builder()
23            .max_capacity(options.max_capacity)
24            .name(format!("sockudo-memory-cache-{prefix}").as_str());
25
26        let cache = if options.ttl > 0 {
27            cache_builder.time_to_live(Duration::from_secs(options.ttl))
28        } else {
29            cache_builder
30        }
31        .build_with_hasher(ahash::RandomState::new());
32
33        Self {
34            cache,
35            options,
36            prefix,
37        }
38    }
39
40    /// Get the prefixed key.
41    fn prefixed_key(&self, key: &str) -> String {
42        format!("{}:{}", self.prefix, key)
43    }
44}
45
46#[async_trait]
47impl CacheManager for MemoryCacheManager {
48    async fn has(&self, key: &str) -> Result<bool> {
49        let prefixed_key = self.prefixed_key(key);
50        let exists = self.cache.get(&prefixed_key).await.is_some();
51        Ok(exists)
52    }
53
54    async fn get(&self, key: &str) -> Result<Option<String>> {
55        let prefixed_key = self.prefixed_key(key);
56        Ok(self.cache.get(&prefixed_key).await)
57    }
58
59    async fn set(&self, key: &str, value: &str, _ttl_seconds: u64) -> Result<()> {
60        let prefixed_key = self.prefixed_key(key);
61        let value_string = value.to_string();
62
63        self.cache.insert(prefixed_key, value_string).await;
64        Ok(())
65    }
66
67    async fn remove(&self, key: &str) -> Result<()> {
68        let prefixed_key = self.prefixed_key(key);
69        self.cache.invalidate(&prefixed_key).await;
70        Ok(())
71    }
72
73    async fn disconnect(&self) -> Result<()> {
74        self.cache.invalidate_all();
75        Ok(())
76    }
77
78    async fn ttl(&self, key: &str) -> Result<Option<Duration>> {
79        let prefixed_key = self.prefixed_key(key);
80        if self.cache.contains_key(&prefixed_key) {
81            if self.options.ttl > 0 {
82                Ok(Some(Duration::from_secs(self.options.ttl)))
83            } else {
84                Ok(None)
85            }
86        } else {
87            Ok(None)
88        }
89    }
90
91    async fn scan_prefix(&self, prefix: &str, limit: usize) -> Result<Vec<(String, String)>> {
92        if limit == 0 {
93            return Ok(Vec::new());
94        }
95
96        let mut entries = Vec::with_capacity(limit.min(64));
97        let cache_prefix = format!("{}:", self.prefix);
98        let prefix_len = cache_prefix.len();
99
100        for (key, value) in self.cache.iter() {
101            if entries.len() >= limit {
102                break;
103            }
104            if !key.starts_with(&cache_prefix) {
105                continue;
106            }
107            let unprefixed_key = &key[prefix_len..];
108            if unprefixed_key.starts_with(prefix) {
109                entries.push((unprefixed_key.to_string(), value.clone()));
110            }
111        }
112
113        Ok(entries)
114    }
115
116    async fn scan_prefix_page(
117        &self,
118        prefix: &str,
119        cursor: Option<String>,
120        limit: usize,
121    ) -> Result<CacheScanPage> {
122        if limit == 0 {
123            return Ok(CacheScanPage::default());
124        }
125
126        let cache_prefix = format!("{}:", self.prefix);
127        let prefix_len = cache_prefix.len();
128        let mut matching = self
129            .cache
130            .iter()
131            .filter_map(|(key, value)| {
132                if !key.starts_with(&cache_prefix) {
133                    return None;
134                }
135                let unprefixed_key = key[prefix_len..].to_string();
136                if unprefixed_key.starts_with(prefix) {
137                    Some((unprefixed_key, value))
138                } else {
139                    None
140                }
141            })
142            .collect::<Vec<_>>();
143        matching.sort_by(|left, right| left.0.cmp(&right.0));
144
145        let start = cursor
146            .as_deref()
147            .and_then(|cursor| matching.iter().position(|(key, _)| key.as_str() > cursor))
148            .unwrap_or(0);
149        let end = start.saturating_add(limit).min(matching.len());
150        let entries = matching[start..end].to_vec();
151        let next_cursor = if end < matching.len() {
152            entries.last().map(|(key, _)| key.clone())
153        } else {
154            None
155        };
156
157        Ok(CacheScanPage {
158            entries,
159            next_cursor,
160        })
161    }
162
163    async fn set_if_not_exists(&self, key: &str, value: &str, _ttl_seconds: u64) -> Result<bool> {
164        let prefixed_key = self.prefixed_key(key);
165        // Moka's `contains_key` + `insert` isn't truly atomic, but for in-memory
166        // single-process use this is sufficient since Tokio tasks on the same
167        // runtime don't preempt each other within a single .await-free block.
168        if self.cache.contains_key(&prefixed_key) {
169            Ok(false)
170        } else {
171            self.cache.insert(prefixed_key, value.to_string()).await;
172            Ok(true)
173        }
174    }
175
176    async fn increment_by(&self, key: &str, delta: i64, _ttl_seconds: u64) -> Result<i64> {
177        let prefixed_key = self.prefixed_key(key);
178        let entry = self
179            .cache
180            .entry(prefixed_key)
181            .and_upsert_with(|entry| {
182                let next = entry
183                    .and_then(|entry| entry.into_value().parse::<i64>().ok())
184                    .unwrap_or(0)
185                    .saturating_add(delta);
186                std::future::ready(next.to_string())
187            })
188            .await;
189        Ok(entry.into_value().parse::<i64>().unwrap_or(0))
190    }
191}
192
193impl MemoryCacheManager {
194    /// Delete a key from the cache.
195    pub async fn delete(&mut self, key: &str) -> Result<bool> {
196        let prefixed_key = self.prefixed_key(key);
197        if self.cache.contains_key(&prefixed_key) {
198            self.cache.invalidate(&prefixed_key).await;
199            Ok(true)
200        } else {
201            Ok(false)
202        }
203    }
204
205    /// Get multiple keys at once.
206    pub async fn get_many(&mut self, keys: &[&str]) -> Result<Vec<Option<String>>> {
207        let mut results = Vec::with_capacity(keys.len());
208        for &key in keys {
209            results.push(self.get(key).await?);
210        }
211        Ok(results)
212    }
213
214    /// Set multiple key-value pairs at once.
215    pub async fn set_many(&mut self, pairs: &[(&str, &str)], _ttl_seconds: u64) -> Result<()> {
216        for (key, value) in pairs {
217            let prefixed_key = self.prefixed_key(key);
218            let value_string = value.to_string();
219            self.cache.insert(prefixed_key, value_string).await;
220        }
221        Ok(())
222    }
223
224    /// Get all entries from the cache as (key, value, ttl) tuples.
225    /// Returns entries without the prefix.
226    ///
227    /// Note: Moka doesn't support per-entry TTL tracking, so this returns the
228    /// cache's default TTL for all entries. When syncing to another cache system,
229    /// this means all entries will get the same TTL, not their remaining time.
230    pub async fn get_all_entries(&self) -> Vec<(String, String, Option<Duration>)> {
231        let mut entries = Vec::new();
232        let prefix_len = self.prefix.len() + 1; // +1 for the colon separator
233
234        for (key, value) in self.cache.iter() {
235            if key.starts_with(&format!("{}:", self.prefix)) {
236                let unprefixed_key = key[prefix_len..].to_string();
237                let ttl = if self.options.ttl > 0 {
238                    Some(Duration::from_secs(self.options.ttl))
239                } else {
240                    None
241                };
242                entries.push((unprefixed_key, value.clone(), ttl));
243            }
244        }
245
246        entries
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Fires many concurrent `increment_by` calls at one counter and checks that
    /// no update is lost.
    #[tokio::test]
    async fn increment_by_serializes_concurrent_updates() {
        const TASKS: usize = 128;

        let options = MemoryCacheOptions {
            ttl: 60,
            cleanup_interval: 60,
            max_capacity: 1_000,
        };
        let shared = Arc::new(MemoryCacheManager::new("test".to_string(), options));

        let mut tasks = Vec::with_capacity(TASKS);
        for _ in 0..TASKS {
            let manager = Arc::clone(&shared);
            tasks.push(tokio::spawn(async move {
                manager.increment_by("counter", 1, 60).await
            }));
        }

        for task in tasks {
            task.await.unwrap().unwrap();
        }

        let stored = shared.get("counter").await.unwrap();
        assert_eq!(stored.as_deref(), Some("128"));
    }
}