Skip to main content

somatize_core/
state.rs

1//! Trained-state storage — authoritative data produced by `fit()`.
2//!
3//! States are distinct from [`CacheStore`](crate::CacheStore) entries:
4//! - Cache entries are **discardable** — the system can recompute them.
5//! - States are **authoritative** — they are the product of training and
6//!   belong to the Graph that produced them. They must not be evicted
7//!   arbitrarily.
8//!
9//! [`StateStore`] is the trait; implementations may keep states in memory,
10//! on local disk, or in object storage. States are returned as
11//! `Arc<Value>` so the hot forward path can borrow them (`&*arc`) without
12//! cloning potentially-large tensors.
13
use crate::error::Result;
use crate::value::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
18
19/// Storage for trained filter states, keyed by node id.
20///
21/// Implementations must be `Send + Sync` and use interior mutability so
22/// the store can be shared (via `Arc`) across the executor and the
23/// graph session.
24pub trait StateStore: Send + Sync {
25    /// Fetch the state for `node_id`, if present.
26    fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>>;
27
28    /// Store `state` under `node_id`, replacing any previous value.
29    fn set(&self, node_id: &str, state: Value) -> Result<()>;
30
31    /// Remove the state for `node_id`, if present.
32    fn remove(&self, node_id: &str) -> Result<()>;
33
34    /// Drop all stored states.
35    fn clear(&self) -> Result<()>;
36
37    /// List all node ids that currently have a stored state.
38    fn keys(&self) -> Result<Vec<String>>;
39}
40
41/// In-memory [`StateStore`] — the default backend.
42///
43/// States live as `Arc<Value>` so reads are zero-copy (just `Arc::clone`)
44/// and multiple consumers can hold references concurrently.
45#[derive(Default)]
46pub struct MemoryStateStore {
47    inner: Mutex<HashMap<String, Arc<Value>>>,
48}
49
50impl MemoryStateStore {
51    pub fn new() -> Self {
52        Self::default()
53    }
54}
55
56impl StateStore for MemoryStateStore {
57    fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>> {
58        let guard = self.inner.lock().expect("MemoryStateStore poisoned");
59        Ok(guard.get(node_id).cloned())
60    }
61
62    fn set(&self, node_id: &str, state: Value) -> Result<()> {
63        let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
64        guard.insert(node_id.to_string(), Arc::new(state));
65        Ok(())
66    }
67
68    fn remove(&self, node_id: &str) -> Result<()> {
69        let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
70        guard.remove(node_id);
71        Ok(())
72    }
73
74    fn clear(&self) -> Result<()> {
75        let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
76        guard.clear();
77        Ok(())
78    }
79
80    fn keys(&self) -> Result<Vec<String>> {
81        let guard = self.inner.lock().expect("MemoryStateStore poisoned");
82        Ok(guard.keys().cloned().collect())
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored state reads back intact, and repeated reads share one `Arc`.
    #[test]
    fn memory_store_roundtrip() {
        let states = MemoryStateStore::new();
        let missing = states.get("a").unwrap();
        assert!(missing.is_none());

        let fitted = Value::json(serde_json::json!({"mean": 5.0}));
        states.set("a", fitted).unwrap();

        let loaded = states.get("a").unwrap().unwrap();
        assert_eq!(loaded.as_json().unwrap()["mean"], 5.0);

        // Reads must not deep-copy: both handles point at the same allocation.
        let (first, second) = (
            states.get("a").unwrap().unwrap(),
            states.get("a").unwrap().unwrap(),
        );
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Removing one id leaves the others; clearing empties the store.
    #[test]
    fn memory_store_remove_and_clear() {
        let states = MemoryStateStore::new();
        for id in ["a", "b"] {
            states.set(id, Value::Empty).unwrap();
        }
        assert_eq!(states.keys().unwrap().len(), 2);

        states.remove("a").unwrap();
        let gone = states.get("a").unwrap();
        let kept = states.get("b").unwrap();
        assert!(gone.is_none());
        assert!(kept.is_some());

        states.clear().unwrap();
        assert!(states.keys().unwrap().is_empty());
    }

    /// A second `set` on the same id replaces the first value.
    #[test]
    fn memory_store_overwrites() {
        let states = MemoryStateStore::new();
        for v in [1, 2] {
            states
                .set("a", Value::json(serde_json::json!({"v": v})))
                .unwrap();
        }
        let latest = states.get("a").unwrap().unwrap();
        assert_eq!(latest.as_json().unwrap()["v"], 2);
    }
}