Skip to main content

skg_state_memory/
lib.rs

1#![deny(missing_docs)]
2//! In-memory implementation of layer0's StateStore trait.
3//!
4//! Uses a `HashMap` behind a `RwLock` for concurrent access.
5//! Scopes are serialized to strings for use as key prefixes,
6//! providing full scope isolation. Supports optional LRU eviction
7//! via [`MemoryStore::bounded`] and basic case-insensitive substring search.
8
9use async_trait::async_trait;
10use layer0::effect::Scope;
11use layer0::error::StateError;
12use layer0::state::{SearchResult, StateStore, StoreOptions};
13use std::collections::{HashMap, HashSet};
14use tokio::sync::RwLock;
15
16/// In-memory state store backed by a `HashMap` behind a `RwLock`.
17///
18/// Suitable for testing, prototyping, and single-process use cases
19/// where persistence across restarts is not required.
20///
21/// Create an unbounded store with [`MemoryStore::new`] or a capacity-limited
22/// LRU store with [`MemoryStore::bounded`].
23pub struct MemoryStore {
24    data: RwLock<HashMap<String, serde_json::Value>>,
25    transient: RwLock<HashMap<String, serde_json::Value>>,
26    capacity: Option<usize>,
27    /// Composite keys ordered by last access, least-recently used at front.
28    access_order: RwLock<Vec<String>>,
29    /// Composite keys marked durable — never evicted by LRU.
30    durable_keys: RwLock<HashSet<String>>,
31}
32
33impl MemoryStore {
34    /// Create a new empty in-memory store with no eviction limit.
35    pub fn new() -> Self {
36        Self {
37            data: RwLock::new(HashMap::new()),
38            transient: RwLock::new(HashMap::new()),
39            capacity: None,
40            access_order: RwLock::new(Vec::new()),
41            durable_keys: RwLock::new(HashSet::new()),
42        }
43    }
44
45    /// Create a bounded in-memory store that evicts least-recently-used entries
46    /// when the entry count exceeds `capacity`.
47    ///
48    /// Reads and writes both count as "use" for LRU ordering.
49    /// Pinned scope entries (written via `write_hinted` with `Lifetime::Durable`)
50    /// are never evicted.
51    pub fn bounded(capacity: usize) -> Self {
52        Self {
53            data: RwLock::new(HashMap::new()),
54            transient: RwLock::new(HashMap::new()),
55            capacity: Some(capacity),
56            access_order: RwLock::new(Vec::new()),
57            durable_keys: RwLock::new(HashSet::new()),
58        }
59    }
60
61    /// Insert or update `ck` in the data map, updating LRU tracking and evicting
62    /// if the store is bounded and over capacity.
63    ///
64    /// `is_durable` marks the key as non-evictable. Transient entries bypass this
65    /// path entirely (they go to the separate `transient` map).
66    ///
67    /// Lock order: `data` → `access_order` → `durable_keys`.
68    async fn write_inner(&self, ck: String, value: serde_json::Value, is_durable: bool) {
69        let mut data = self.data.write().await;
70        let mut order = self.access_order.write().await;
71        let mut durable = self.durable_keys.write().await;
72
73        if is_durable {
74            durable.insert(ck.clone());
75        }
76
77        // Remove any existing position, then push to back (most-recently used).
78        order.retain(|k| k != &ck);
79        order.push(ck.clone());
80        data.insert(ck, value);
81
82        // Evict least-recently-used non-durable entries until within capacity.
83        if let Some(cap) = self.capacity {
84            while data.len() > cap {
85                // Find the front-most key that is not durable.
86                let evict_idx = order.iter().position(|k| !durable.contains(k));
87                match evict_idx {
88                    Some(idx) => {
89                        let evict_ck = order.remove(idx);
90                        data.remove(&evict_ck);
91                    }
92                    // All remaining keys are durable — cannot evict further.
93                    None => break,
94                }
95            }
96        }
97    }
98}
99
100impl Default for MemoryStore {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106/// Build a composite key from scope + key to ensure isolation.
107fn composite_key(scope: &Scope, key: &str) -> String {
108    let scope_str = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
109    format!("{scope_str}\0{key}")
110}
111
/// Recover the user-facing key from `composite`.
///
/// Returns `None` unless `composite` starts with `scope_prefix` immediately
/// followed by the NUL separator; otherwise returns everything after it.
fn extract_key<'a>(composite: &'a str, scope_prefix: &str) -> Option<&'a str> {
    let after_scope = composite.strip_prefix(scope_prefix)?;
    after_scope.strip_prefix('\0')
}
118
119#[async_trait]
120impl StateStore for MemoryStore {
121    async fn read(
122        &self,
123        scope: &Scope,
124        key: &str,
125    ) -> Result<Option<serde_json::Value>, StateError> {
126        let ck = composite_key(scope, key);
127        // Drop the read lock before acquiring the write lock on access_order
128        // to avoid holding two locks simultaneously (data.read + order.write).
129        let value = self.data.read().await.get(&ck).cloned();
130        if value.is_some() {
131            let mut order = self.access_order.write().await;
132            order.retain(|k| k != &ck);
133            order.push(ck);
134        }
135        Ok(value)
136    }
137
138    async fn write(
139        &self,
140        scope: &Scope,
141        key: &str,
142        value: serde_json::Value,
143    ) -> Result<(), StateError> {
144        let ck = composite_key(scope, key);
145        self.write_inner(ck, value, false).await;
146        Ok(())
147    }
148
149    async fn delete(&self, scope: &Scope, key: &str) -> Result<(), StateError> {
150        let ck = composite_key(scope, key);
151        self.data.write().await.remove(&ck);
152        self.access_order.write().await.retain(|k| k != &ck);
153        self.durable_keys.write().await.remove(&ck);
154        Ok(())
155    }
156
157    async fn list(&self, scope: &Scope, prefix: &str) -> Result<Vec<String>, StateError> {
158        let scope_prefix = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
159        let data = self.data.read().await;
160        let keys: Vec<String> = data
161            .keys()
162            .filter_map(|ck| {
163                extract_key(ck, &scope_prefix).and_then(|k| {
164                    if k.starts_with(prefix) {
165                        Some(k.to_string())
166                    } else {
167                        None
168                    }
169                })
170            })
171            .collect();
172        Ok(keys)
173    }
174
175    async fn search(
176        &self,
177        scope: &Scope,
178        query: &str,
179        limit: usize,
180    ) -> Result<Vec<SearchResult>, StateError> {
181        if query.is_empty() || limit == 0 {
182            return Ok(vec![]);
183        }
184
185        let scope_prefix = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
186        let query_lower = query.to_lowercase();
187
188        let data = self.data.read().await;
189        let mut results: Vec<SearchResult> = data
190            .iter()
191            .filter_map(|(ck, value)| {
192                let key = extract_key(ck, &scope_prefix)?;
193                let text = value.to_string();
194                let text_lower = text.to_lowercase();
195
196                let count = text_lower.matches(query_lower.as_str()).count();
197                if count == 0 {
198                    return None;
199                }
200
201                // Score: occurrence density — more occurrences in shorter text ranks higher.
202                let score = count as f64 / text_lower.len().max(1) as f64;
203                let mut result = SearchResult::new(key, score);
204                result.snippet = Some(if text.len() > 200 {
205                    format!("{}...", &text[..200])
206                } else {
207                    text
208                });
209                Some(result)
210            })
211            .collect();
212
213        results.sort_by(|a, b| {
214            b.score
215                .partial_cmp(&a.score)
216                .unwrap_or(std::cmp::Ordering::Equal)
217        });
218        results.truncate(limit);
219        Ok(results)
220    }
221
222    async fn write_hinted(
223        &self,
224        scope: &Scope,
225        key: &str,
226        value: serde_json::Value,
227        options: &StoreOptions,
228    ) -> Result<(), StateError> {
229        use layer0::state::Lifetime;
230        match options.lifetime {
231            Some(Lifetime::Transient) => {
232                let ck = composite_key(scope, key);
233                self.transient.write().await.insert(ck, value);
234            }
235            Some(Lifetime::Durable) => {
236                let ck = composite_key(scope, key);
237                self.write_inner(ck, value, true).await;
238            }
239            _ => {
240                self.write(scope, key, value).await?;
241            }
242        }
243        Ok(())
244    }
245
246    fn clear_transient(&self) {
247        // Use try_write; if the lock is contended, skip — best-effort clearing.
248        if let Ok(mut t) = self.transient.try_write() {
249            t.clear();
250        }
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn write_and_read() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("value1")).await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, Some(json!("value1")));
    }

    #[tokio::test]
    async fn read_nonexistent_returns_none() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        let val = store.read(&scope, "missing").await.unwrap();
        assert_eq!(val, None);
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("first")).await.unwrap();
        store.write(&scope, "key1", json!("second")).await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, Some(json!("second")));
    }

    #[tokio::test]
    async fn delete_removes_key() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("value1")).await.unwrap();
        store.delete(&scope, "key1").await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, None);
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        let result = store.delete(&scope, "missing").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn list_keys_with_prefix() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "user:name", json!("Alice"))
            .await
            .unwrap();
        store.write(&scope, "user:age", json!(30)).await.unwrap();
        store
            .write(&scope, "system:version", json!("1.0"))
            .await
            .unwrap();

        let mut keys = store.list(&scope, "user:").await.unwrap();
        keys.sort();
        assert_eq!(keys, vec!["user:age", "user:name"]);
    }

    #[tokio::test]
    async fn list_empty_prefix_returns_all() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "a", json!(1)).await.unwrap();
        store.write(&scope, "b", json!(2)).await.unwrap();

        let keys = store.list(&scope, "").await.unwrap();
        assert_eq!(keys.len(), 2);
    }

    #[tokio::test]
    async fn scopes_are_isolated() {
        let store = MemoryStore::new();
        let global = Scope::Global;
        let session = Scope::Session(layer0::SessionId::new("s1"));

        store
            .write(&global, "key", json!("global_val"))
            .await
            .unwrap();
        store
            .write(&session, "key", json!("session_val"))
            .await
            .unwrap();

        let global_val = store.read(&global, "key").await.unwrap();
        let session_val = store.read(&session, "key").await.unwrap();

        assert_eq!(global_val, Some(json!("global_val")));
        assert_eq!(session_val, Some(json!("session_val")));
    }

    #[tokio::test]
    async fn search_returns_empty_on_no_match() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("hello world"))
            .await
            .unwrap();
        let results = store.search(&scope, "xyzzy", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn default_store_is_empty() {
        let store = MemoryStore::default();
        // Previously this only checked construction; actually verify emptiness.
        let keys = store.list(&Scope::Global, "").await.unwrap();
        assert!(keys.is_empty(), "a default store must start with no entries");
    }

    #[test]
    fn memory_store_implements_state_store() {
        fn _assert_state_store<T: StateStore>() {}
        _assert_state_store::<MemoryStore>();
    }

    #[tokio::test]
    async fn test_transient_write_not_durable() {
        use layer0::state::Lifetime;

        let store = MemoryStore::new();
        let scope = Scope::Global;

        // Write transient entry
        let opts = StoreOptions {
            lifetime: Some(Lifetime::Transient),
            ..Default::default()
        };
        store
            .write_hinted(&scope, "scratch", serde_json::json!("temp"), &opts)
            .await
            .unwrap();

        // Transient entries are not visible via read()
        let val = store.read(&scope, "scratch").await.unwrap();
        assert_eq!(val, None, "transient entry must not be visible via read()");

        // clear_transient is idempotent
        store.clear_transient();
        store.clear_transient();

        // Write a durable entry
        store
            .write(&scope, "durable", serde_json::json!("persisted"))
            .await
            .unwrap();

        // clear_transient does not touch durable storage
        store.clear_transient();

        let durable_val = store.read(&scope, "durable").await.unwrap();
        assert_eq!(
            durable_val,
            Some(serde_json::json!("persisted")),
            "durable entry must survive clear_transient()"
        );
    }

    // ── LRU / bounded tests ──────────────────────────────────────────────────

    #[tokio::test]
    async fn bounded_evicts_oldest() {
        let store = MemoryStore::bounded(3);
        let scope = Scope::Global;

        for k in ["a", "b", "c", "d", "e"] {
            store.write(&scope, k, json!(k)).await.unwrap();
        }

        assert_eq!(
            store.read(&scope, "a").await.unwrap(),
            None,
            "a should be evicted"
        );
        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert_eq!(store.read(&scope, "c").await.unwrap(), Some(json!("c")));
        assert_eq!(store.read(&scope, "d").await.unwrap(), Some(json!("d")));
        assert_eq!(store.read(&scope, "e").await.unwrap(), Some(json!("e")));
    }

    #[tokio::test]
    async fn bounded_read_refreshes_lru() {
        let store = MemoryStore::bounded(3);
        let scope = Scope::Global;

        store.write(&scope, "a", json!("a")).await.unwrap();
        store.write(&scope, "b", json!("b")).await.unwrap();
        store.write(&scope, "c", json!("c")).await.unwrap();

        // Touch "a" — it becomes most-recently used; order becomes [b, c, a].
        let _ = store.read(&scope, "a").await.unwrap();

        // Write "d" — should evict "b" (now at front), not "a".
        store.write(&scope, "d", json!("d")).await.unwrap();

        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert!(
            store.read(&scope, "a").await.unwrap().is_some(),
            "a should survive"
        );
        assert!(
            store.read(&scope, "c").await.unwrap().is_some(),
            "c should survive"
        );
        assert!(
            store.read(&scope, "d").await.unwrap().is_some(),
            "d should survive"
        );
    }

    #[tokio::test]
    async fn bounded_unlimited_default() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        for i in 0..100u32 {
            store.write(&scope, &i.to_string(), json!(i)).await.unwrap();
        }

        for i in 0..100u32 {
            assert!(
                store.read(&scope, &i.to_string()).await.unwrap().is_some(),
                "key {i} should not be evicted from unbounded store",
            );
        }
    }

    #[tokio::test]
    async fn bounded_never_evicts_durable() {
        use layer0::state::Lifetime;

        let store = MemoryStore::bounded(2);
        let scope = Scope::Global;
        let durable = StoreOptions {
            lifetime: Some(Lifetime::Durable),
            ..Default::default()
        };

        store
            .write_hinted(&scope, "pinned", json!("keep"), &durable)
            .await
            .unwrap();
        for k in ["a", "b", "c"] {
            store.write(&scope, k, json!(k)).await.unwrap();
        }

        assert_eq!(
            store.read(&scope, "pinned").await.unwrap(),
            Some(json!("keep")),
            "durable entry must never be evicted"
        );
        assert_eq!(
            store.read(&scope, "a").await.unwrap(),
            None,
            "a should be evicted"
        );
        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert_eq!(store.read(&scope, "c").await.unwrap(), Some(json!("c")));
    }

    #[tokio::test]
    async fn delete_clears_durable_pin() {
        use layer0::state::Lifetime;

        let store = MemoryStore::bounded(1);
        let scope = Scope::Global;
        let durable = StoreOptions {
            lifetime: Some(Lifetime::Durable),
            ..Default::default()
        };

        store
            .write_hinted(&scope, "p", json!(1), &durable)
            .await
            .unwrap();
        store.delete(&scope, "p").await.unwrap();

        // Re-written without a durable hint, "p" is an ordinary LRU entry again.
        store.write(&scope, "p", json!(2)).await.unwrap();
        store.write(&scope, "q", json!(3)).await.unwrap();

        assert_eq!(
            store.read(&scope, "p").await.unwrap(),
            None,
            "p should be evictable once its durable pin is deleted"
        );
        assert_eq!(store.read(&scope, "q").await.unwrap(), Some(json!(3)));
    }

    // ── Search tests ─────────────────────────────────────────────────────────

    #[tokio::test]
    async fn search_finds_substring() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("hello world"))
            .await
            .unwrap();
        store
            .write(&scope, "k2", json!("goodbye world"))
            .await
            .unwrap();
        store.write(&scope, "k3", json!(42)).await.unwrap();

        let results = store.search(&scope, "world", 10).await.unwrap();
        let keys: Vec<&str> = results.iter().map(|r| r.key.as_str()).collect();
        assert!(keys.contains(&"k1"), "k1 should match");
        assert!(keys.contains(&"k2"), "k2 should match");
        assert!(!keys.contains(&"k3"), "k3 should not match");
    }

    #[tokio::test]
    async fn search_case_insensitive() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("Hello World"))
            .await
            .unwrap();
        store.write(&scope, "k2", json!("HELLO")).await.unwrap();
        store.write(&scope, "k3", json!("unrelated")).await.unwrap();

        let results = store.search(&scope, "hello", 10).await.unwrap();
        let keys: Vec<&str> = results.iter().map(|r| r.key.as_str()).collect();
        assert!(keys.contains(&"k1"), "k1 should match case-insensitively");
        assert!(keys.contains(&"k2"), "k2 should match case-insensitively");
        assert!(!keys.contains(&"k3"), "k3 should not match");
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        for i in 0..10u32 {
            store
                .write(&scope, &format!("k{i}"), json!("needle in haystack"))
                .await
                .unwrap();
        }

        let results = store.search(&scope, "needle", 3).await.unwrap();
        assert_eq!(results.len(), 3, "results must be capped at the limit");
    }
}
565}