1use crate::error::Result;
15use crate::value::Value;
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18
19pub trait StateStore: Send + Sync {
25 fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>>;
27
28 fn set(&self, node_id: &str, state: Value) -> Result<()>;
30
31 fn remove(&self, node_id: &str) -> Result<()>;
33
34 fn clear(&self) -> Result<()>;
36
37 fn keys(&self) -> Result<Vec<String>>;
39}
40
41#[derive(Default)]
46pub struct MemoryStateStore {
47 inner: Mutex<HashMap<String, Arc<Value>>>,
48}
49
50impl MemoryStateStore {
51 pub fn new() -> Self {
52 Self::default()
53 }
54}
55
56impl StateStore for MemoryStateStore {
57 fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>> {
58 let guard = self.inner.lock().expect("MemoryStateStore poisoned");
59 Ok(guard.get(node_id).cloned())
60 }
61
62 fn set(&self, node_id: &str, state: Value) -> Result<()> {
63 let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
64 guard.insert(node_id.to_string(), Arc::new(state));
65 Ok(())
66 }
67
68 fn remove(&self, node_id: &str) -> Result<()> {
69 let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
70 guard.remove(node_id);
71 Ok(())
72 }
73
74 fn clear(&self) -> Result<()> {
75 let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
76 guard.clear();
77 Ok(())
78 }
79
80 fn keys(&self) -> Result<Vec<String>> {
81 let guard = self.inner.lock().expect("MemoryStateStore poisoned");
82 Ok(guard.keys().cloned().collect())
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_store_roundtrip() {
        let store = MemoryStateStore::new();
        assert!(store.get("a").unwrap().is_none());

        store
            .set("a", Value::Json(serde_json::json!({"mean": 5.0})))
            .unwrap();
        let state = store.get("a").unwrap().unwrap();
        assert_eq!(state.as_json().unwrap()["mean"], 5.0);

        // Repeated reads hand out the same shared allocation, not copies.
        let s1 = store.get("a").unwrap().unwrap();
        let s2 = store.get("a").unwrap().unwrap();
        assert!(Arc::ptr_eq(&s1, &s2));
    }

    #[test]
    fn memory_store_remove_and_clear() {
        let store = MemoryStateStore::new();
        store.set("a", Value::Empty).unwrap();
        store.set("b", Value::Empty).unwrap();

        // `keys()` has no defined order, so sort before comparing.
        let mut keys = store.keys().unwrap();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);

        store.remove("a").unwrap();
        assert!(store.get("a").unwrap().is_none());
        assert!(store.get("b").unwrap().is_some());

        // Removing an absent key is a no-op rather than an error.
        store.remove("a").unwrap();
        store.remove("never-set").unwrap();
        assert_eq!(store.keys().unwrap(), vec!["b".to_string()]);

        store.clear().unwrap();
        assert!(store.keys().unwrap().is_empty());
        // Clearing an already-empty store is also fine.
        store.clear().unwrap();
        assert!(store.keys().unwrap().is_empty());
    }

    #[test]
    fn memory_store_overwrites() {
        let store = MemoryStateStore::new();
        store
            .set("a", Value::Json(serde_json::json!({"v": 1})))
            .unwrap();
        let before = store.get("a").unwrap().unwrap();
        store
            .set("a", Value::Json(serde_json::json!({"v": 2})))
            .unwrap();
        let state = store.get("a").unwrap().unwrap();
        assert_eq!(state.as_json().unwrap()["v"], 2);

        // A handle fetched before the overwrite keeps seeing the old value.
        assert_eq!(before.as_json().unwrap()["v"], 1);
        // Overwriting replaces the entry instead of adding a second one.
        assert_eq!(store.keys().unwrap(), vec!["a".to_string()]);
    }

    #[test]
    fn memory_store_is_shareable_across_threads() {
        let store = Arc::new(MemoryStateStore::new());
        let handles: Vec<_> = (0..8)
            .map(|i| {
                let store = Arc::clone(&store);
                std::thread::spawn(move || {
                    store.set(&format!("node-{}", i), Value::Empty).unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(store.keys().unwrap().len(), 8);
    }
}