Skip to main content

sage_runtime/persistence/
mod.rs

1//! Persistence support for @persistent agent beliefs.
2//!
3//! This module provides the runtime support for persistent agent state:
4//! - `CheckpointStore` trait for storage backends
5//! - `Persisted<T>` wrapper for auto-checkpointing fields
//! - `agent_checkpoint_key` and `checkpoint_all` helpers for agent-level persistence
7//!
8//! # Backends
9//!
//! The following backends are available via feature flags (native targets only;
//! they are never compiled for `wasm32`):
11//! - `persistence-sqlite`: SQLite database (recommended for local development)
12//! - `persistence-postgres`: PostgreSQL (recommended for production)
13//! - `persistence-file`: JSON files (useful for debugging)
14//!
//! Without any persistence feature, or when targeting `wasm32`, only
//! `MemoryCheckpointStore` is available.
16
// Synchronous adapters over the async persistence backends. They are built only
// for native targets, and only when at least one backend feature is enabled.
#[cfg(all(
    not(target_arch = "wasm32"),
    any(
        feature = "persistence-file",
        feature = "persistence-postgres",
        feature = "persistence-sqlite"
    )
))]
mod backends;

// One re-export per backend, each gated on the same native-only condition as
// the module plus its own feature flag.
#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-sqlite"))]
pub use backends::SyncSqliteStore;

#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-postgres"))]
pub use backends::SyncPostgresStore;

#[cfg(all(not(target_arch = "wasm32"), feature = "persistence-file"))]
pub use backends::SyncFileStore;
34
35use serde::{de::DeserializeOwned, Serialize};
36use std::collections::HashMap;
37use std::sync::{Arc, RwLock};
38
39/// A checkpoint store for persisting agent state.
40///
41/// This is a re-export of the trait from sage-persistence, simplified
42/// for use in generated code.
43pub trait CheckpointStore: Send + Sync {
44    /// Save a field value synchronously (blocks on async).
45    fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value);
46
47    /// Load a field value synchronously.
48    fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value>;
49
50    /// Load all fields for an agent.
51    fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value>;
52
53    /// Save all fields atomically.
54    fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>);
55
56    /// Check if any checkpoint exists for an agent.
57    fn exists_sync(&self, agent_key: &str) -> bool;
58}
59
60/// In-memory checkpoint store for testing.
61#[derive(Default)]
62pub struct MemoryCheckpointStore {
63    data: RwLock<HashMap<String, HashMap<String, serde_json::Value>>>,
64}
65
66impl MemoryCheckpointStore {
67    pub fn new() -> Self {
68        Self::default()
69    }
70}
71
72impl CheckpointStore for MemoryCheckpointStore {
73    fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value) {
74        let mut data = self.data.write().unwrap();
75        data.entry(agent_key.to_string())
76            .or_default()
77            .insert(field.to_string(), value);
78    }
79
80    fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value> {
81        self.data
82            .read()
83            .unwrap()
84            .get(agent_key)
85            .and_then(|fields| fields.get(field).cloned())
86    }
87
88    fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value> {
89        self.data
90            .read()
91            .unwrap()
92            .get(agent_key)
93            .cloned()
94            .unwrap_or_default()
95    }
96
97    fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>) {
98        let mut data = self.data.write().unwrap();
99        data.insert(agent_key.to_string(), fields.clone());
100    }
101
102    fn exists_sync(&self, agent_key: &str) -> bool {
103        self.data.read().unwrap().contains_key(agent_key)
104    }
105}
106
107/// A wrapper for @persistent fields that auto-checkpoints on modification.
108///
109/// This provides interior mutability and automatic persistence when the
110/// value is modified via `set()`.
111pub struct Persisted<T> {
112    value: RwLock<T>,
113    store: Arc<dyn CheckpointStore>,
114    agent_key: String,
115    field_name: String,
116}
117
118impl<T: Clone + Serialize + DeserializeOwned + Default + Send> Persisted<T> {
119    /// Create a new persisted field, loading from checkpoint if available.
120    pub fn new(
121        store: Arc<dyn CheckpointStore>,
122        agent_key: impl Into<String>,
123        field_name: impl Into<String>,
124    ) -> Self {
125        let agent_key = agent_key.into();
126        let field_name = field_name.into();
127
128        // Try to load from checkpoint
129        let value = store
130            .load_sync(&agent_key, &field_name)
131            .and_then(|v| serde_json::from_value(v).ok())
132            .unwrap_or_default();
133
134        Self {
135            value: RwLock::new(value),
136            store,
137            agent_key,
138            field_name,
139        }
140    }
141
142    /// Create with an explicit initial value (used when no checkpoint exists).
143    pub fn with_initial(
144        store: Arc<dyn CheckpointStore>,
145        agent_key: impl Into<String>,
146        field_name: impl Into<String>,
147        initial: T,
148    ) -> Self {
149        let agent_key = agent_key.into();
150        let field_name = field_name.into();
151
152        // Try to load from checkpoint, fall back to initial
153        let value = store
154            .load_sync(&agent_key, &field_name)
155            .and_then(|v| serde_json::from_value(v).ok())
156            .unwrap_or(initial);
157
158        Self {
159            value: RwLock::new(value),
160            store,
161            agent_key,
162            field_name,
163        }
164    }
165
166    /// Get the current value.
167    pub fn get(&self) -> T {
168        self.value.read().unwrap().clone()
169    }
170
171    /// Set the value and checkpoint it.
172    pub fn set(&self, new_value: T) {
173        *self.value.write().unwrap() = new_value.clone();
174        if let Ok(json) = serde_json::to_value(&new_value) {
175            self.store
176                .save_sync(&self.agent_key, &self.field_name, json);
177        }
178    }
179
180    /// Checkpoint the current value without modifying it.
181    pub fn checkpoint(&self) {
182        let value = self.value.read().unwrap().clone();
183        if let Ok(json) = serde_json::to_value(&value) {
184            self.store
185                .save_sync(&self.agent_key, &self.field_name, json);
186        }
187    }
188}
189
190/// Helper to generate a unique checkpoint key for an agent instance.
191pub fn agent_checkpoint_key(agent_name: &str, beliefs: &serde_json::Value) -> String {
192    use std::collections::hash_map::DefaultHasher;
193    use std::hash::{Hash, Hasher};
194
195    let mut hasher = DefaultHasher::new();
196    agent_name.hash(&mut hasher);
197    beliefs.to_string().hash(&mut hasher);
198    format!("{}_{:016x}", agent_name, hasher.finish())
199}
200
201/// Helper to save all @persistent fields atomically before yield.
202pub fn checkpoint_all<S: CheckpointStore + ?Sized>(
203    store: &S,
204    agent_key: &str,
205    fields: Vec<(&str, serde_json::Value)>,
206) {
207    let map: HashMap<String, serde_json::Value> = fields
208        .into_iter()
209        .map(|(k, v)| (k.to_string(), v))
210        .collect();
211    store.save_all_sync(agent_key, &map);
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    fn make_store() -> Arc<dyn CheckpointStore> {
        Arc::new(MemoryCheckpointStore::new())
    }

    #[test]
    fn memory_store_save_load() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "count", serde_json::json!(42));

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn memory_store_reports_existence() {
        let store = MemoryCheckpointStore::new();
        assert!(!store.exists_sync("agent1"));
        store.save_sync("agent1", "count", serde_json::json!(1));
        assert!(store.exists_sync("agent1"));
        assert!(!store.exists_sync("agent2"));
    }

    #[test]
    fn persisted_field_loads_from_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(100));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 100);
    }

    #[test]
    fn persisted_field_defaults_when_no_checkpoint() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0); // Default for i64
    }

    #[test]
    fn persisted_field_defaults_when_checkpoint_is_undecodable() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!("not a number"));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0);
    }

    #[test]
    fn persisted_with_initial_prefers_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(7));

        let field: Persisted<i64> =
            Persisted::with_initial(Arc::clone(&store), "agent1", "count", 99);
        assert_eq!(field.get(), 7);

        let fresh: Persisted<i64> = Persisted::with_initial(store, "agent2", "count", 99);
        assert_eq!(fresh.get(), 99);
    }

    #[test]
    fn persisted_field_auto_checkpoints_on_set() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");

        field.set(42);

        // Verify it was persisted
        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn persisted_checkpoint_saves_current_value() {
        let store = make_store();
        let field: Persisted<i64> =
            Persisted::with_initial(Arc::clone(&store), "agent1", "count", 5);
        assert_eq!(store.load_sync("agent1", "count"), None);

        field.checkpoint();
        assert_eq!(store.load_sync("agent1", "count"), Some(serde_json::json!(5)));
    }

    #[test]
    fn checkpoint_all_saves_every_field() {
        let store = MemoryCheckpointStore::new();
        checkpoint_all(
            &store,
            "agent1",
            vec![
                ("count", serde_json::json!(3)),
                ("name", serde_json::json!("bot")),
            ],
        );

        let mut expected = HashMap::new();
        expected.insert("count".to_string(), serde_json::json!(3));
        expected.insert("name".to_string(), serde_json::json!("bot"));
        assert_eq!(store.load_all_sync("agent1"), expected);
        assert!(store.exists_sync("agent1"));
    }

    #[test]
    fn checkpoint_key_varies_with_beliefs() {
        let key1 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 1}));
        let key2 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 2}));
        assert_ne!(key1, key2);
    }

    #[test]
    fn checkpoint_key_matches_legacy_default_hasher_scheme() {
        // Keys are persisted, so their derivation must not drift. This compares
        // against the original `DefaultHasher`-based scheme. If it ever fails
        // only because std changed `DefaultHasher`, pin literal keys instead.
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let cases = [
            ("", serde_json::json!(null)),
            ("Agent", serde_json::json!({"x": 1})),
            ("Counter", serde_json::json!([1, 2, 3])),
            (
                "ResearchAssistant",
                serde_json::json!({"name": "bot", "tags": ["a", "b"], "n": 12345678901u64}),
            ),
        ];
        for (name, beliefs) in &cases {
            let mut hasher = DefaultHasher::new();
            name.hash(&mut hasher);
            beliefs.to_string().hash(&mut hasher);
            let expected = format!("{}_{:016x}", name, hasher.finish());
            assert_eq!(agent_checkpoint_key(name, beliefs), expected);
        }
    }
}