Skip to main content

serdes_ai_graph/
persistence.rs

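//! State persistence for graph execution.
//!
//! Defines the [`StatePersistence`] trait for checkpointing a run's state and
//! final result, plus two backends: [`InMemoryPersistence`] (shared in-process
//! maps) and [`FilePersistence`] (JSON files in a directory).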
2
3use crate::error::GraphError;
4use async_trait::async_trait;
5use parking_lot::RwLock;
6use serde::{de::DeserializeOwned, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11/// Error during persistence operations.
12#[derive(Debug, thiserror::Error)]
13pub enum PersistenceError {
14    /// IO error.
15    #[error("IO error: {0}")]
16    Io(#[from] std::io::Error),
17
18    /// Serialization error.
19    #[error("Serialization error: {0}")]
20    Serialization(#[from] serde_json::Error),
21
22    /// State not found.
23    #[error("State not found for run: {0}")]
24    NotFound(String),
25
26    /// Other error.
27    #[error("{0}")]
28    Other(String),
29}
30
31impl From<PersistenceError> for GraphError {
32    fn from(e: PersistenceError) -> Self {
33        GraphError::Persistence(e.to_string())
34    }
35}
36
37/// Trait for persisting graph state.
38#[async_trait]
39pub trait StatePersistence<State, End>: Send + Sync {
40    /// Save state for a run.
41    async fn save_state(
42        &self,
43        run_id: &str,
44        state: &State,
45        step: u32,
46    ) -> Result<(), PersistenceError>;
47
48    /// Load state for a run.
49    async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError>;
50
51    /// Save the final result.
52    async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError>;
53
54    /// Load the final result.
55    async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError>;
56
57    /// Delete state for a run.
58    async fn delete(&self, run_id: &str) -> Result<(), PersistenceError>;
59
60    /// List all stored run IDs.
61    async fn list_runs(&self) -> Result<Vec<String>, PersistenceError>;
62}
63
64/// In-memory state persistence.
65#[derive(Clone)]
66pub struct InMemoryPersistence<State, End> {
67    states: Arc<RwLock<HashMap<String, (State, u32)>>>,
68    results: Arc<RwLock<HashMap<String, End>>>,
69}
70
71impl<State, End> InMemoryPersistence<State, End> {
72    /// Create a new in-memory persistence store.
73    pub fn new() -> Self {
74        Self {
75            states: Arc::new(RwLock::new(HashMap::new())),
76            results: Arc::new(RwLock::new(HashMap::new())),
77        }
78    }
79
80    /// Clear all stored data.
81    pub fn clear(&self) {
82        self.states.write().clear();
83        self.results.write().clear();
84    }
85
86    /// Get the number of stored states.
87    pub fn state_count(&self) -> usize {
88        self.states.read().len()
89    }
90
91    /// Get the number of stored results.
92    pub fn result_count(&self) -> usize {
93        self.results.read().len()
94    }
95}
96
97impl<State, End> Default for InMemoryPersistence<State, End> {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103#[async_trait]
104impl<State, End> StatePersistence<State, End> for InMemoryPersistence<State, End>
105where
106    State: Clone + Send + Sync + 'static,
107    End: Clone + Send + Sync + 'static,
108{
109    async fn save_state(
110        &self,
111        run_id: &str,
112        state: &State,
113        step: u32,
114    ) -> Result<(), PersistenceError> {
115        self.states
116            .write()
117            .insert(run_id.to_string(), (state.clone(), step));
118        Ok(())
119    }
120
121    async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError> {
122        Ok(self.states.read().get(run_id).cloned())
123    }
124
125    async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError> {
126        self.results
127            .write()
128            .insert(run_id.to_string(), result.clone());
129        Ok(())
130    }
131
132    async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError> {
133        Ok(self.results.read().get(run_id).cloned())
134    }
135
136    async fn delete(&self, run_id: &str) -> Result<(), PersistenceError> {
137        self.states.write().remove(run_id);
138        self.results.write().remove(run_id);
139        Ok(())
140    }
141
142    async fn list_runs(&self) -> Result<Vec<String>, PersistenceError> {
143        let state_keys: std::collections::HashSet<_> = self.states.read().keys().cloned().collect();
144        let result_keys: std::collections::HashSet<_> =
145            self.results.read().keys().cloned().collect();
146        Ok(state_keys.union(&result_keys).cloned().collect())
147    }
148}
149
/// File-based state persistence.
///
/// Each run is stored as up to two pretty-printed JSON files inside
/// `directory`: `<run_id>_state.json` and `<run_id>_result.json`.
#[derive(Debug, Clone)]
pub struct FilePersistence {
    /// Directory holding every run's files; created on the first save.
    directory: PathBuf,
}
154
155impl FilePersistence {
156    /// Create a new file-based persistence store.
157    pub fn new(directory: impl Into<PathBuf>) -> Self {
158        Self {
159            directory: directory.into(),
160        }
161    }
162
163    /// Ensure the directory exists.
164    pub async fn ensure_dir(&self) -> Result<(), PersistenceError> {
165        tokio::fs::create_dir_all(&self.directory).await?;
166        Ok(())
167    }
168
169    fn state_path(&self, run_id: &str) -> PathBuf {
170        self.directory.join(format!("{}_state.json", run_id))
171    }
172
173    fn result_path(&self, run_id: &str) -> PathBuf {
174        self.directory.join(format!("{}_result.json", run_id))
175    }
176}
177
178#[async_trait]
179impl<State, End> StatePersistence<State, End> for FilePersistence
180where
181    State: Serialize + DeserializeOwned + Send + Sync + 'static,
182    End: Serialize + DeserializeOwned + Send + Sync + 'static,
183{
184    async fn save_state(
185        &self,
186        run_id: &str,
187        state: &State,
188        step: u32,
189    ) -> Result<(), PersistenceError> {
190        self.ensure_dir().await?;
191        let path = self.state_path(run_id);
192        let data = serde_json::json!({
193            "state": state,
194            "step": step
195        });
196        let content = serde_json::to_string_pretty(&data)?;
197        tokio::fs::write(&path, content).await?;
198        Ok(())
199    }
200
201    async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError> {
202        let path = self.state_path(run_id);
203        if !path.exists() {
204            return Ok(None);
205        }
206
207        let content = tokio::fs::read_to_string(&path).await?;
208        let value: serde_json::Value = serde_json::from_str(&content)?;
209        let state: State = serde_json::from_value(value["state"].clone())?;
210        let step = value["step"].as_u64().unwrap_or(0) as u32;
211        Ok(Some((state, step)))
212    }
213
214    async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError> {
215        self.ensure_dir().await?;
216        let path = self.result_path(run_id);
217        let content = serde_json::to_string_pretty(result)?;
218        tokio::fs::write(&path, content).await?;
219        Ok(())
220    }
221
222    async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError> {
223        let path = self.result_path(run_id);
224        if !path.exists() {
225            return Ok(None);
226        }
227
228        let content = tokio::fs::read_to_string(&path).await?;
229        let result: End = serde_json::from_str(&content)?;
230        Ok(Some(result))
231    }
232
233    async fn delete(&self, run_id: &str) -> Result<(), PersistenceError> {
234        let state_path = self.state_path(run_id);
235        let result_path = self.result_path(run_id);
236
237        if state_path.exists() {
238            tokio::fs::remove_file(&state_path).await?;
239        }
240        if result_path.exists() {
241            tokio::fs::remove_file(&result_path).await?;
242        }
243        Ok(())
244    }
245
246    async fn list_runs(&self) -> Result<Vec<String>, PersistenceError> {
247        if !self.directory.exists() {
248            return Ok(Vec::new());
249        }
250
251        let mut runs = std::collections::HashSet::new();
252        let mut entries = tokio::fs::read_dir(&self.directory).await?;
253
254        while let Some(entry) = entries.next_entry().await? {
255            let name = entry.file_name().to_string_lossy().to_string();
256            if let Some(run_id) = name
257                .strip_suffix("_state.json")
258                .or_else(|| name.strip_suffix("_result.json"))
259            {
260                runs.insert(run_id.to_string());
261            }
262        }
263
264        Ok(runs.into_iter().collect())
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestState {
        value: i32,
    }

    /// Per-test scratch directory, removed again when the guard is dropped.
    ///
    /// The name includes a per-test tag and the process id, so tests running in
    /// parallel threads (or concurrent `cargo test` processes) never share files.
    /// Cleanup runs on drop, so a failing assertion does not leave files behind.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("serdes_ai_test_{}_{}", tag, std::process::id()));
            // Clear leftovers from an earlier process that happened to reuse this pid.
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn test_in_memory_persistence() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 42 };
        persistence.save_state("run1", &state, 5).await.unwrap();

        let loaded = persistence.load_state("run1").await.unwrap();
        assert_eq!(loaded, Some((TestState { value: 42 }, 5)));
    }

    #[tokio::test]
    async fn test_in_memory_result() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        persistence
            .save_result("run1", &"success".to_string())
            .await
            .unwrap();

        let loaded = persistence.load_result("run1").await.unwrap();
        assert_eq!(loaded, Some("success".to_string()));
        assert!(persistence.load_result("other").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_delete() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 1 };
        persistence.save_state("run1", &state, 1).await.unwrap();
        persistence.save_result("run1", &"done".to_string()).await.unwrap();
        persistence.delete("run1").await.unwrap();

        assert!(persistence.load_state("run1").await.unwrap().is_none());
        assert!(persistence.load_result("run1").await.unwrap().is_none());
        assert_eq!(persistence.state_count(), 0);
        assert_eq!(persistence.result_count(), 0);

        // Deleting an unknown run is not an error.
        persistence.delete("missing").await.unwrap();
    }

    #[tokio::test]
    async fn test_in_memory_list_runs() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 1 };
        persistence.save_state("run1", &state, 1).await.unwrap();
        persistence.save_state("run2", &state, 1).await.unwrap();
        // A run with both a state and a result must still be listed once.
        persistence.save_result("run1", &"done".to_string()).await.unwrap();

        let mut runs = persistence.list_runs().await.unwrap();
        runs.sort();
        assert_eq!(runs, vec!["run1".to_string(), "run2".to_string()]);
    }

    #[tokio::test]
    async fn test_in_memory_clone_shares_storage() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();
        let handle = persistence.clone();

        handle.save_state("run1", &TestState { value: 3 }, 2).await.unwrap();
        assert_eq!(persistence.state_count(), 1);

        persistence.clear();
        assert_eq!(handle.state_count(), 0);
    }

    #[tokio::test]
    async fn test_file_persistence() {
        let dir = ScratchDir::new("state");
        let persistence = FilePersistence::new(&dir.0);
        let store: &dyn StatePersistence<TestState, String> = &persistence;

        let state = TestState { value: 42 };
        store.save_state("test_run", &state, 5).await.unwrap();
        let loaded = store.load_state("test_run").await.unwrap();
        assert_eq!(loaded, Some((TestState { value: 42 }, 5)));

        // A second save replaces the first checkpoint.
        store.save_state("test_run", &TestState { value: 43 }, 6).await.unwrap();
        let loaded = store.load_state("test_run").await.unwrap();
        assert_eq!(loaded, Some((TestState { value: 43 }, 6)));

        store.delete("test_run").await.unwrap();
        assert!(store.load_state("test_run").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_file_result_and_list_runs() {
        let dir = ScratchDir::new("result");
        let persistence = FilePersistence::new(&dir.0);
        let store: &dyn StatePersistence<TestState, String> = &persistence;

        store.save_state("a", &TestState { value: 7 }, 1).await.unwrap();
        store.save_result("a", &"done".to_string()).await.unwrap();
        store.save_result("b", &"ok".to_string()).await.unwrap();

        assert_eq!(
            store.load_result("a").await.unwrap(),
            Some("done".to_string())
        );

        let mut runs = store.list_runs().await.unwrap();
        runs.sort();
        assert_eq!(runs, vec!["a".to_string(), "b".to_string()]);

        store.delete("a").await.unwrap();
        assert_eq!(store.list_runs().await.unwrap(), vec!["b".to_string()]);
    }

    #[tokio::test]
    async fn test_file_missing_directory() {
        let dir = ScratchDir::new("missing");
        let persistence = FilePersistence::new(dir.0.join("never_created"));
        let store: &dyn StatePersistence<TestState, String> = &persistence;

        assert!(store.load_state("run").await.unwrap().is_none());
        assert!(store.load_result("run").await.unwrap().is_none());
        store.delete("run").await.unwrap();
        assert!(store.list_runs().await.unwrap().is_empty());
    }
}