1#![deny(missing_docs)]
2use async_trait::async_trait;
10use layer0::effect::Scope;
11use layer0::error::StateError;
12use layer0::state::{SearchResult, StateStore, StoreOptions};
13use std::collections::{HashMap, HashSet};
14use tokio::sync::RwLock;
15
16pub struct MemoryStore {
24 data: RwLock<HashMap<String, serde_json::Value>>,
25 transient: RwLock<HashMap<String, serde_json::Value>>,
26 capacity: Option<usize>,
27 access_order: RwLock<Vec<String>>,
29 durable_keys: RwLock<HashSet<String>>,
31}
32
33impl MemoryStore {
34 pub fn new() -> Self {
36 Self {
37 data: RwLock::new(HashMap::new()),
38 transient: RwLock::new(HashMap::new()),
39 capacity: None,
40 access_order: RwLock::new(Vec::new()),
41 durable_keys: RwLock::new(HashSet::new()),
42 }
43 }
44
45 pub fn bounded(capacity: usize) -> Self {
52 Self {
53 data: RwLock::new(HashMap::new()),
54 transient: RwLock::new(HashMap::new()),
55 capacity: Some(capacity),
56 access_order: RwLock::new(Vec::new()),
57 durable_keys: RwLock::new(HashSet::new()),
58 }
59 }
60
61 async fn write_inner(&self, ck: String, value: serde_json::Value, is_durable: bool) {
69 let mut data = self.data.write().await;
70 let mut order = self.access_order.write().await;
71 let mut durable = self.durable_keys.write().await;
72
73 if is_durable {
74 durable.insert(ck.clone());
75 }
76
77 order.retain(|k| k != &ck);
79 order.push(ck.clone());
80 data.insert(ck, value);
81
82 if let Some(cap) = self.capacity {
84 while data.len() > cap {
85 let evict_idx = order.iter().position(|k| !durable.contains(k));
87 match evict_idx {
88 Some(idx) => {
89 let evict_ck = order.remove(idx);
90 data.remove(&evict_ck);
91 }
92 None => break,
94 }
95 }
96 }
97 }
98}
99
100impl Default for MemoryStore {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106fn composite_key(scope: &Scope, key: &str) -> String {
108 let scope_str = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
109 format!("{scope_str}\0{key}")
110}
111
/// Returns the user key stored in `composite` if it belongs to `scope_prefix`.
///
/// `composite` must start with `scope_prefix` immediately followed by the NUL
/// separator; otherwise `None` is returned.
fn extract_key<'a>(composite: &'a str, scope_prefix: &str) -> Option<&'a str> {
    let after_scope = composite.strip_prefix(scope_prefix)?;
    after_scope.strip_prefix('\0')
}
118
119#[async_trait]
120impl StateStore for MemoryStore {
121 async fn read(
122 &self,
123 scope: &Scope,
124 key: &str,
125 ) -> Result<Option<serde_json::Value>, StateError> {
126 let ck = composite_key(scope, key);
127 let value = self.data.read().await.get(&ck).cloned();
130 if value.is_some() {
131 let mut order = self.access_order.write().await;
132 order.retain(|k| k != &ck);
133 order.push(ck);
134 }
135 Ok(value)
136 }
137
138 async fn write(
139 &self,
140 scope: &Scope,
141 key: &str,
142 value: serde_json::Value,
143 ) -> Result<(), StateError> {
144 let ck = composite_key(scope, key);
145 self.write_inner(ck, value, false).await;
146 Ok(())
147 }
148
149 async fn delete(&self, scope: &Scope, key: &str) -> Result<(), StateError> {
150 let ck = composite_key(scope, key);
151 self.data.write().await.remove(&ck);
152 self.access_order.write().await.retain(|k| k != &ck);
153 self.durable_keys.write().await.remove(&ck);
154 Ok(())
155 }
156
157 async fn list(&self, scope: &Scope, prefix: &str) -> Result<Vec<String>, StateError> {
158 let scope_prefix = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
159 let data = self.data.read().await;
160 let keys: Vec<String> = data
161 .keys()
162 .filter_map(|ck| {
163 extract_key(ck, &scope_prefix).and_then(|k| {
164 if k.starts_with(prefix) {
165 Some(k.to_string())
166 } else {
167 None
168 }
169 })
170 })
171 .collect();
172 Ok(keys)
173 }
174
175 async fn search(
176 &self,
177 scope: &Scope,
178 query: &str,
179 limit: usize,
180 ) -> Result<Vec<SearchResult>, StateError> {
181 if query.is_empty() || limit == 0 {
182 return Ok(vec![]);
183 }
184
185 let scope_prefix = serde_json::to_string(scope).unwrap_or_else(|_| "unknown".to_string());
186 let query_lower = query.to_lowercase();
187
188 let data = self.data.read().await;
189 let mut results: Vec<SearchResult> = data
190 .iter()
191 .filter_map(|(ck, value)| {
192 let key = extract_key(ck, &scope_prefix)?;
193 let text = value.to_string();
194 let text_lower = text.to_lowercase();
195
196 let count = text_lower.matches(query_lower.as_str()).count();
197 if count == 0 {
198 return None;
199 }
200
201 let score = count as f64 / text_lower.len().max(1) as f64;
203 let mut result = SearchResult::new(key, score);
204 result.snippet = Some(if text.len() > 200 {
205 format!("{}...", &text[..200])
206 } else {
207 text
208 });
209 Some(result)
210 })
211 .collect();
212
213 results.sort_by(|a, b| {
214 b.score
215 .partial_cmp(&a.score)
216 .unwrap_or(std::cmp::Ordering::Equal)
217 });
218 results.truncate(limit);
219 Ok(results)
220 }
221
222 async fn write_hinted(
223 &self,
224 scope: &Scope,
225 key: &str,
226 value: serde_json::Value,
227 options: &StoreOptions,
228 ) -> Result<(), StateError> {
229 use layer0::state::Lifetime;
230 match options.lifetime {
231 Some(Lifetime::Transient) => {
232 let ck = composite_key(scope, key);
233 self.transient.write().await.insert(ck, value);
234 }
235 Some(Lifetime::Durable) => {
236 let ck = composite_key(scope, key);
237 self.write_inner(ck, value, true).await;
238 }
239 _ => {
240 self.write(scope, key, value).await?;
241 }
242 }
243 Ok(())
244 }
245
246 fn clear_transient(&self) {
247 if let Ok(mut t) = self.transient.try_write() {
249 t.clear();
250 }
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn write_and_read() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("value1")).await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, Some(json!("value1")));
    }

    #[tokio::test]
    async fn read_nonexistent_returns_none() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        let val = store.read(&scope, "missing").await.unwrap();
        assert_eq!(val, None);
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("first")).await.unwrap();
        store.write(&scope, "key1", json!("second")).await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, Some(json!("second")));
    }

    #[tokio::test]
    async fn delete_removes_key() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "key1", json!("value1")).await.unwrap();
        store.delete(&scope, "key1").await.unwrap();
        let val = store.read(&scope, "key1").await.unwrap();
        assert_eq!(val, None);
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        let result = store.delete(&scope, "missing").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn list_keys_with_prefix() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "user:name", json!("Alice"))
            .await
            .unwrap();
        store.write(&scope, "user:age", json!(30)).await.unwrap();
        store
            .write(&scope, "system:version", json!("1.0"))
            .await
            .unwrap();

        let mut keys = store.list(&scope, "user:").await.unwrap();
        keys.sort();
        assert_eq!(keys, vec!["user:age", "user:name"]);
    }

    #[tokio::test]
    async fn list_empty_prefix_returns_all() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store.write(&scope, "a", json!(1)).await.unwrap();
        store.write(&scope, "b", json!(2)).await.unwrap();

        let keys = store.list(&scope, "").await.unwrap();
        assert_eq!(keys.len(), 2);
    }

    #[tokio::test]
    async fn scopes_are_isolated() {
        let store = MemoryStore::new();
        let global = Scope::Global;
        let session = Scope::Session(layer0::SessionId::new("s1"));

        store
            .write(&global, "key", json!("global_val"))
            .await
            .unwrap();
        store
            .write(&session, "key", json!("session_val"))
            .await
            .unwrap();

        let global_val = store.read(&global, "key").await.unwrap();
        let session_val = store.read(&session, "key").await.unwrap();

        assert_eq!(global_val, Some(json!("global_val")));
        assert_eq!(session_val, Some(json!("session_val")));
    }

    #[tokio::test]
    async fn search_returns_empty_on_no_match() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("hello world"))
            .await
            .unwrap();
        let results = store.search(&scope, "xyzzy", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn default_store_is_empty() {
        // Previously this test only constructed the store and asserted nothing.
        let store = MemoryStore::default();
        let keys = store.list(&Scope::Global, "").await.unwrap();
        assert!(keys.is_empty(), "a default store must start with no keys");
    }

    #[test]
    fn memory_store_implements_state_store() {
        fn _assert_state_store<T: StateStore>() {}
        _assert_state_store::<MemoryStore>();
    }

    #[tokio::test]
    async fn test_transient_write_not_durable() {
        use layer0::state::{Lifetime, StoreOptions};

        let store = MemoryStore::new();
        let scope = Scope::Global;

        let opts = StoreOptions {
            lifetime: Some(Lifetime::Transient),
            ..Default::default()
        };
        store
            .write_hinted(&scope, "scratch", serde_json::json!("temp"), &opts)
            .await
            .unwrap();

        let val = store.read(&scope, "scratch").await.unwrap();
        assert_eq!(val, None, "transient entry must not be visible via read()");

        store.clear_transient();
        // Clearing an already-empty transient map must be a harmless no-op.
        store.clear_transient();

        store
            .write(&scope, "durable", serde_json::json!("persisted"))
            .await
            .unwrap();

        store.clear_transient();

        let durable_val = store.read(&scope, "durable").await.unwrap();
        assert_eq!(
            durable_val,
            Some(serde_json::json!("persisted")),
            "durable entry must survive clear_transient()"
        );
    }

    #[tokio::test]
    async fn bounded_evicts_oldest() {
        let store = MemoryStore::bounded(3);
        let scope = Scope::Global;

        for k in ["a", "b", "c", "d", "e"] {
            store.write(&scope, k, json!(k)).await.unwrap();
        }

        assert_eq!(
            store.read(&scope, "a").await.unwrap(),
            None,
            "a should be evicted"
        );
        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert_eq!(store.read(&scope, "c").await.unwrap(), Some(json!("c")));
        assert_eq!(store.read(&scope, "d").await.unwrap(), Some(json!("d")));
        assert_eq!(store.read(&scope, "e").await.unwrap(), Some(json!("e")));
    }

    #[tokio::test]
    async fn bounded_read_refreshes_lru() {
        let store = MemoryStore::bounded(3);
        let scope = Scope::Global;

        store.write(&scope, "a", json!("a")).await.unwrap();
        store.write(&scope, "b", json!("b")).await.unwrap();
        store.write(&scope, "c", json!("c")).await.unwrap();

        // Touching "a" makes "b" the least recently used entry.
        let _ = store.read(&scope, "a").await.unwrap();

        store.write(&scope, "d", json!("d")).await.unwrap();

        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert!(
            store.read(&scope, "a").await.unwrap().is_some(),
            "a should survive"
        );
        assert!(
            store.read(&scope, "c").await.unwrap().is_some(),
            "c should survive"
        );
        assert!(
            store.read(&scope, "d").await.unwrap().is_some(),
            "d should survive"
        );
    }

    #[tokio::test]
    async fn bounded_never_evicts_durable() {
        use layer0::state::{Lifetime, StoreOptions};

        let store = MemoryStore::bounded(2);
        let scope = Scope::Global;
        let durable = StoreOptions {
            lifetime: Some(Lifetime::Durable),
            ..Default::default()
        };

        store
            .write_hinted(&scope, "keep", json!("kept"), &durable)
            .await
            .unwrap();
        for k in ["a", "b", "c"] {
            store.write(&scope, k, json!(k)).await.unwrap();
        }

        assert_eq!(
            store.read(&scope, "keep").await.unwrap(),
            Some(json!("kept")),
            "durable entry must never be evicted"
        );
        assert_eq!(
            store.read(&scope, "a").await.unwrap(),
            None,
            "a should be evicted"
        );
        assert_eq!(
            store.read(&scope, "b").await.unwrap(),
            None,
            "b should be evicted"
        );
        assert_eq!(store.read(&scope, "c").await.unwrap(), Some(json!("c")));
    }

    #[tokio::test]
    async fn bounded_unlimited_default() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        for i in 0..100u32 {
            store.write(&scope, &i.to_string(), json!(i)).await.unwrap();
        }

        for i in 0..100u32 {
            assert!(
                store.read(&scope, &i.to_string()).await.unwrap().is_some(),
                "key {i} should not be evicted from unbounded store",
            );
        }
    }

    #[tokio::test]
    async fn search_finds_substring() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("hello world"))
            .await
            .unwrap();
        store
            .write(&scope, "k2", json!("goodbye world"))
            .await
            .unwrap();
        store.write(&scope, "k3", json!(42)).await.unwrap();

        let results = store.search(&scope, "world", 10).await.unwrap();
        let keys: Vec<&str> = results.iter().map(|r| r.key.as_str()).collect();
        assert!(keys.contains(&"k1"), "k1 should match");
        assert!(keys.contains(&"k2"), "k2 should match");
        assert!(!keys.contains(&"k3"), "k3 should not match");
    }

    #[tokio::test]
    async fn search_case_insensitive() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "k1", json!("Hello World"))
            .await
            .unwrap();
        store.write(&scope, "k2", json!("HELLO")).await.unwrap();
        store.write(&scope, "k3", json!("unrelated")).await.unwrap();

        let results = store.search(&scope, "hello", 10).await.unwrap();
        let keys: Vec<&str> = results.iter().map(|r| r.key.as_str()).collect();
        assert!(keys.contains(&"k1"), "k1 should match case-insensitively");
        assert!(keys.contains(&"k2"), "k2 should match case-insensitively");
        assert!(!keys.contains(&"k3"), "k3 should not match");
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        for i in 0..10u32 {
            store
                .write(&scope, &format!("k{i}"), json!("needle in haystack"))
                .await
                .unwrap();
        }

        let results = store.search(&scope, "needle", 3).await.unwrap();
        assert_eq!(results.len(), 3, "results must be capped at the limit");
    }

    #[tokio::test]
    async fn search_truncates_long_snippet() {
        let store = MemoryStore::new();
        let scope = Scope::Global;

        store
            .write(&scope, "long", json!("x".repeat(300)))
            .await
            .unwrap();
        store.write(&scope, "short", json!("y")).await.unwrap();

        let long = store.search(&scope, "x", 1).await.unwrap();
        let snippet = long[0].snippet.as_deref().expect("snippet must be set");
        // 200 characters of the JSON text (opening quote + 199 'x') plus "...".
        assert_eq!(snippet.len(), 203);
        assert!(snippet.starts_with("\"x"));
        assert!(snippet.ends_with("..."));

        let short = store.search(&scope, "y", 1).await.unwrap();
        assert_eq!(short[0].snippet.as_deref(), Some("\"y\""));
    }
}