//! sage_runtime/persistence/mod.rs — checkpoint persistence for agent state.
//!
//! Provides the [`CheckpointStore`] trait, the in-memory [`MemoryCheckpointStore`], and the
//! [`Persisted`] field wrapper. On non-wasm32 targets the `persistence-file`,
//! `persistence-postgres`, and `persistence-sqlite` features compile the `backends` module
//! and re-export the matching store type.

// The backends module is only built on native targets with at least one backend feature on.
#[cfg(all(
    not(target_arch = "wasm32"),
    any(
        feature = "persistence-file",
        feature = "persistence-postgres",
        feature = "persistence-sqlite"
    )
))]
mod backends;

#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-file"))]
pub use backends::SyncFileStore;
#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-postgres"))]
pub use backends::SyncPostgresStore;
#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-sqlite"))]
pub use backends::SyncSqliteStore;
34
35use serde::{de::DeserializeOwned, Serialize};
36use std::collections::HashMap;
37use std::sync::{Arc, RwLock};
38
39pub trait CheckpointStore: Send + Sync {
44 fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value);
46
47 fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value>;
49
50 fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value>;
52
53 fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>);
55
56 fn exists_sync(&self, agent_key: &str) -> bool;
58}
59
60#[derive(Default)]
62pub struct MemoryCheckpointStore {
63 data: RwLock<HashMap<String, HashMap<String, serde_json::Value>>>,
64}
65
66impl MemoryCheckpointStore {
67 pub fn new() -> Self {
68 Self::default()
69 }
70}
71
72impl CheckpointStore for MemoryCheckpointStore {
73 fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value) {
74 let mut data = self.data.write().unwrap();
75 data.entry(agent_key.to_string())
76 .or_default()
77 .insert(field.to_string(), value);
78 }
79
80 fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value> {
81 self.data
82 .read()
83 .unwrap()
84 .get(agent_key)
85 .and_then(|fields| fields.get(field).cloned())
86 }
87
88 fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value> {
89 self.data
90 .read()
91 .unwrap()
92 .get(agent_key)
93 .cloned()
94 .unwrap_or_default()
95 }
96
97 fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>) {
98 let mut data = self.data.write().unwrap();
99 data.insert(agent_key.to_string(), fields.clone());
100 }
101
102 fn exists_sync(&self, agent_key: &str) -> bool {
103 self.data.read().unwrap().contains_key(agent_key)
104 }
105}
106
107pub struct Persisted<T> {
112 value: RwLock<T>,
113 store: Arc<dyn CheckpointStore>,
114 agent_key: String,
115 field_name: String,
116}
117
118impl<T: Clone + Serialize + DeserializeOwned + Default + Send> Persisted<T> {
119 pub fn new(
121 store: Arc<dyn CheckpointStore>,
122 agent_key: impl Into<String>,
123 field_name: impl Into<String>,
124 ) -> Self {
125 let agent_key = agent_key.into();
126 let field_name = field_name.into();
127
128 let value = store
130 .load_sync(&agent_key, &field_name)
131 .and_then(|v| serde_json::from_value(v).ok())
132 .unwrap_or_default();
133
134 Self {
135 value: RwLock::new(value),
136 store,
137 agent_key,
138 field_name,
139 }
140 }
141
142 pub fn with_initial(
144 store: Arc<dyn CheckpointStore>,
145 agent_key: impl Into<String>,
146 field_name: impl Into<String>,
147 initial: T,
148 ) -> Self {
149 let agent_key = agent_key.into();
150 let field_name = field_name.into();
151
152 let value = store
154 .load_sync(&agent_key, &field_name)
155 .and_then(|v| serde_json::from_value(v).ok())
156 .unwrap_or(initial);
157
158 Self {
159 value: RwLock::new(value),
160 store,
161 agent_key,
162 field_name,
163 }
164 }
165
166 pub fn get(&self) -> T {
168 self.value.read().unwrap().clone()
169 }
170
171 pub fn set(&self, new_value: T) {
173 *self.value.write().unwrap() = new_value.clone();
174 if let Ok(json) = serde_json::to_value(&new_value) {
175 self.store
176 .save_sync(&self.agent_key, &self.field_name, json);
177 }
178 }
179
180 pub fn checkpoint(&self) {
182 let value = self.value.read().unwrap().clone();
183 if let Ok(json) = serde_json::to_value(&value) {
184 self.store
185 .save_sync(&self.agent_key, &self.field_name, json);
186 }
187 }
188}
189
190pub fn agent_checkpoint_key(agent_name: &str, beliefs: &serde_json::Value) -> String {
192 use std::collections::hash_map::DefaultHasher;
193 use std::hash::{Hash, Hasher};
194
195 let mut hasher = DefaultHasher::new();
196 agent_name.hash(&mut hasher);
197 beliefs.to_string().hash(&mut hasher);
198 format!("{}_{:016x}", agent_name, hasher.finish())
199}
200
201pub fn checkpoint_all<S: CheckpointStore + ?Sized>(
203 store: &S,
204 agent_key: &str,
205 fields: Vec<(&str, serde_json::Value)>,
206) {
207 let map: HashMap<String, serde_json::Value> = fields
208 .into_iter()
209 .map(|(k, v)| (k.to_string(), v))
210 .collect();
211 store.save_all_sync(agent_key, &map);
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    fn make_store() -> Arc<dyn CheckpointStore> {
        Arc::new(MemoryCheckpointStore::new())
    }

    #[test]
    fn memory_store_save_load() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "count", serde_json::json!(42));

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn memory_store_save_all_replaces_agent_fields() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "stale", serde_json::json!(1));

        checkpoint_all(
            &store,
            "agent1",
            vec![("a", serde_json::json!(1)), ("b", serde_json::json!("two"))],
        );

        let all = store.load_all_sync("agent1");
        assert_eq!(all.len(), 2);
        assert_eq!(all.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(all.get("b"), Some(&serde_json::json!("two")));
        assert!(store.exists_sync("agent1"));
        assert!(!store.exists_sync("agent2"));
        assert!(store.load_all_sync("agent2").is_empty());
    }

    #[test]
    fn persisted_field_loads_from_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(100));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 100);
    }

    #[test]
    fn persisted_field_defaults_when_no_checkpoint() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0);
    }

    #[test]
    fn persisted_field_defaults_when_checkpoint_has_wrong_type() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!("not a number"));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0);
    }

    #[test]
    fn with_initial_prefers_checkpoint_over_initial() {
        let store = make_store();
        store.save_sync("agent1", "name", serde_json::json!("saved"));

        let field = Persisted::with_initial(store, "agent1", "name", String::from("initial"));
        assert_eq!(field.get(), "saved");
    }

    #[test]
    fn with_initial_used_when_no_checkpoint() {
        let store = make_store();
        let field = Persisted::with_initial(store, "agent1", "name", String::from("initial"));
        assert_eq!(field.get(), "initial");
    }

    #[test]
    fn persisted_field_auto_checkpoints_on_set() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");

        field.set(42);

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn checkpoint_rewrites_current_value() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");
        field.set(3);
        // Simulate the stored copy drifting away from the live value.
        store.save_sync("agent1", "count", serde_json::json!(99));

        field.checkpoint();

        assert_eq!(store.load_sync("agent1", "count"), Some(serde_json::json!(3)));
    }

    #[test]
    fn checkpoint_key_varies_with_beliefs() {
        let key1 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 1}));
        let key2 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 2}));
        assert_ne!(key1, key2);
    }

    #[test]
    fn checkpoint_key_is_deterministic_and_prefixed() {
        let beliefs = serde_json::json!({"x": 1});
        let key = agent_checkpoint_key("Agent", &beliefs);

        assert_eq!(key, agent_checkpoint_key("Agent", &beliefs));
        assert!(key.starts_with("Agent_"));
        assert_eq!(key.len(), "Agent_".len() + 16);
    }
}