1use crate::error::GraphError;
4use async_trait::async_trait;
5use parking_lot::RwLock;
6use serde::{de::DeserializeOwned, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11#[derive(Debug, thiserror::Error)]
13pub enum PersistenceError {
14 #[error("IO error: {0}")]
16 Io(#[from] std::io::Error),
17
18 #[error("Serialization error: {0}")]
20 Serialization(#[from] serde_json::Error),
21
22 #[error("State not found for run: {0}")]
24 NotFound(String),
25
26 #[error("{0}")]
28 Other(String),
29}
30
31impl From<PersistenceError> for GraphError {
32 fn from(e: PersistenceError) -> Self {
33 GraphError::Persistence(e.to_string())
34 }
35}
36
37#[async_trait]
39pub trait StatePersistence<State, End>: Send + Sync {
40 async fn save_state(
42 &self,
43 run_id: &str,
44 state: &State,
45 step: u32,
46 ) -> Result<(), PersistenceError>;
47
48 async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError>;
50
51 async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError>;
53
54 async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError>;
56
57 async fn delete(&self, run_id: &str) -> Result<(), PersistenceError>;
59
60 async fn list_runs(&self) -> Result<Vec<String>, PersistenceError>;
62}
63
/// Process-local [`StatePersistence`] backend that keeps everything in two hash maps.
///
/// Both maps live behind `Arc<RwLock<_>>`. Cloning is therefore cheap and yields a handle
/// to the *same* storage rather than a copy of it.
pub struct InMemoryPersistence<State, End> {
    /// Latest checkpoint per run id, as `(state, step)`.
    states: Arc<RwLock<HashMap<String, (State, u32)>>>,
    /// Final result per run id.
    results: Arc<RwLock<HashMap<String, End>>>,
}

// Implemented by hand: `#[derive(Clone)]` would demand `State: Clone` and `End: Clone`,
// although only the `Arc` handles are cloned here.
impl<State, End> Clone for InMemoryPersistence<State, End> {
    fn clone(&self) -> Self {
        Self {
            states: Arc::clone(&self.states),
            results: Arc::clone(&self.results),
        }
    }
}
70
71impl<State, End> InMemoryPersistence<State, End> {
72 pub fn new() -> Self {
74 Self {
75 states: Arc::new(RwLock::new(HashMap::new())),
76 results: Arc::new(RwLock::new(HashMap::new())),
77 }
78 }
79
80 pub fn clear(&self) {
82 self.states.write().clear();
83 self.results.write().clear();
84 }
85
86 pub fn state_count(&self) -> usize {
88 self.states.read().len()
89 }
90
91 pub fn result_count(&self) -> usize {
93 self.results.read().len()
94 }
95}
96
97impl<State, End> Default for InMemoryPersistence<State, End> {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103#[async_trait]
104impl<State, End> StatePersistence<State, End> for InMemoryPersistence<State, End>
105where
106 State: Clone + Send + Sync + 'static,
107 End: Clone + Send + Sync + 'static,
108{
109 async fn save_state(
110 &self,
111 run_id: &str,
112 state: &State,
113 step: u32,
114 ) -> Result<(), PersistenceError> {
115 self.states
116 .write()
117 .insert(run_id.to_string(), (state.clone(), step));
118 Ok(())
119 }
120
121 async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError> {
122 Ok(self.states.read().get(run_id).cloned())
123 }
124
125 async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError> {
126 self.results
127 .write()
128 .insert(run_id.to_string(), result.clone());
129 Ok(())
130 }
131
132 async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError> {
133 Ok(self.results.read().get(run_id).cloned())
134 }
135
136 async fn delete(&self, run_id: &str) -> Result<(), PersistenceError> {
137 self.states.write().remove(run_id);
138 self.results.write().remove(run_id);
139 Ok(())
140 }
141
142 async fn list_runs(&self) -> Result<Vec<String>, PersistenceError> {
143 let state_keys: std::collections::HashSet<_> = self.states.read().keys().cloned().collect();
144 let result_keys: std::collections::HashSet<_> =
145 self.results.read().keys().cloned().collect();
146 Ok(state_keys.union(&result_keys).cloned().collect())
147 }
148}
149
/// [`StatePersistence`] backend that stores each run as JSON files inside `directory`.
///
/// Each run gets a `<run_id>_state.json` checkpoint file and a `<run_id>_result.json`
/// result file.
#[derive(Debug, Clone)]
pub struct FilePersistence {
    /// Directory holding all run files; created on demand.
    directory: PathBuf,
}
154
155impl FilePersistence {
156 pub fn new(directory: impl Into<PathBuf>) -> Self {
158 Self {
159 directory: directory.into(),
160 }
161 }
162
163 pub async fn ensure_dir(&self) -> Result<(), PersistenceError> {
165 tokio::fs::create_dir_all(&self.directory).await?;
166 Ok(())
167 }
168
169 fn state_path(&self, run_id: &str) -> PathBuf {
170 self.directory.join(format!("{}_state.json", run_id))
171 }
172
173 fn result_path(&self, run_id: &str) -> PathBuf {
174 self.directory.join(format!("{}_result.json", run_id))
175 }
176}
177
178#[async_trait]
179impl<State, End> StatePersistence<State, End> for FilePersistence
180where
181 State: Serialize + DeserializeOwned + Send + Sync + 'static,
182 End: Serialize + DeserializeOwned + Send + Sync + 'static,
183{
184 async fn save_state(
185 &self,
186 run_id: &str,
187 state: &State,
188 step: u32,
189 ) -> Result<(), PersistenceError> {
190 self.ensure_dir().await?;
191 let path = self.state_path(run_id);
192 let data = serde_json::json!({
193 "state": state,
194 "step": step
195 });
196 let content = serde_json::to_string_pretty(&data)?;
197 tokio::fs::write(&path, content).await?;
198 Ok(())
199 }
200
201 async fn load_state(&self, run_id: &str) -> Result<Option<(State, u32)>, PersistenceError> {
202 let path = self.state_path(run_id);
203 if !path.exists() {
204 return Ok(None);
205 }
206
207 let content = tokio::fs::read_to_string(&path).await?;
208 let value: serde_json::Value = serde_json::from_str(&content)?;
209 let state: State = serde_json::from_value(value["state"].clone())?;
210 let step = value["step"].as_u64().unwrap_or(0) as u32;
211 Ok(Some((state, step)))
212 }
213
214 async fn save_result(&self, run_id: &str, result: &End) -> Result<(), PersistenceError> {
215 self.ensure_dir().await?;
216 let path = self.result_path(run_id);
217 let content = serde_json::to_string_pretty(result)?;
218 tokio::fs::write(&path, content).await?;
219 Ok(())
220 }
221
222 async fn load_result(&self, run_id: &str) -> Result<Option<End>, PersistenceError> {
223 let path = self.result_path(run_id);
224 if !path.exists() {
225 return Ok(None);
226 }
227
228 let content = tokio::fs::read_to_string(&path).await?;
229 let result: End = serde_json::from_str(&content)?;
230 Ok(Some(result))
231 }
232
233 async fn delete(&self, run_id: &str) -> Result<(), PersistenceError> {
234 let state_path = self.state_path(run_id);
235 let result_path = self.result_path(run_id);
236
237 if state_path.exists() {
238 tokio::fs::remove_file(&state_path).await?;
239 }
240 if result_path.exists() {
241 tokio::fs::remove_file(&result_path).await?;
242 }
243 Ok(())
244 }
245
246 async fn list_runs(&self) -> Result<Vec<String>, PersistenceError> {
247 if !self.directory.exists() {
248 return Ok(Vec::new());
249 }
250
251 let mut runs = std::collections::HashSet::new();
252 let mut entries = tokio::fs::read_dir(&self.directory).await?;
253
254 while let Some(entry) = entries.next_entry().await? {
255 let name = entry.file_name().to_string_lossy().to_string();
256 if let Some(run_id) = name
257 .strip_suffix("_state.json")
258 .or_else(|| name.strip_suffix("_result.json"))
259 {
260 runs.insert(run_id.to_string());
261 }
262 }
263
264 Ok(runs.into_iter().collect())
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestState {
        value: i32,
    }

    /// Returns a directory path unique to this process and call.
    ///
    /// This keeps parallel tests and concurrent `cargo test` runs from sharing, or deleting,
    /// each other's files.
    fn unique_temp_dir(tag: &str) -> std::path::PathBuf {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        std::env::temp_dir().join(format!(
            "serdes_ai_test_{}_{}_{}",
            tag,
            std::process::id(),
            n
        ))
    }

    #[tokio::test]
    async fn test_in_memory_persistence() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 42 };
        persistence.save_state("run1", &state, 5).await.unwrap();

        let loaded = persistence.load_state("run1").await.unwrap();
        assert!(loaded.is_some());
        let (loaded_state, step) = loaded.unwrap();
        assert_eq!(loaded_state.value, 42);
        assert_eq!(step, 5);
    }

    #[tokio::test]
    async fn test_in_memory_result() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        persistence
            .save_result("run1", &"success".to_string())
            .await
            .unwrap();

        let loaded = persistence.load_result("run1").await.unwrap();
        assert_eq!(loaded, Some("success".to_string()));
    }

    #[tokio::test]
    async fn test_in_memory_delete() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 1 };
        persistence.save_state("run1", &state, 1).await.unwrap();
        persistence
            .save_result("run1", &"done".to_string())
            .await
            .unwrap();
        persistence.delete("run1").await.unwrap();

        assert!(persistence.load_state("run1").await.unwrap().is_none());
        assert!(persistence.load_result("run1").await.unwrap().is_none());
        // Deleting an unknown run is not an error.
        persistence.delete("never_saved").await.unwrap();
    }

    #[tokio::test]
    async fn test_in_memory_list_runs() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();

        let state = TestState { value: 1 };
        persistence.save_state("run1", &state, 1).await.unwrap();
        persistence.save_state("run2", &state, 1).await.unwrap();

        let runs = persistence.list_runs().await.unwrap();
        assert_eq!(runs.len(), 2);

        // A run with both a checkpoint and a result is listed once; result-only runs appear too.
        persistence
            .save_result("run1", &"done".to_string())
            .await
            .unwrap();
        persistence
            .save_result("run3", &"done".to_string())
            .await
            .unwrap();
        let mut runs = persistence.list_runs().await.unwrap();
        runs.sort();
        assert_eq!(runs, vec!["run1", "run2", "run3"]);
    }

    #[tokio::test]
    async fn test_in_memory_clear_and_counts() {
        let persistence: InMemoryPersistence<TestState, String> = InMemoryPersistence::new();
        persistence
            .save_state("run1", &TestState { value: 1 }, 1)
            .await
            .unwrap();
        persistence
            .save_result("run1", &"done".to_string())
            .await
            .unwrap();
        assert_eq!(persistence.state_count(), 1);
        assert_eq!(persistence.result_count(), 1);

        // Clones share storage, so clearing through a clone empties the original as well.
        persistence.clone().clear();
        assert_eq!(persistence.state_count(), 0);
        assert_eq!(persistence.result_count(), 0);
    }

    #[tokio::test]
    async fn test_file_persistence() {
        let temp_dir = unique_temp_dir("roundtrip");
        let persistence = FilePersistence::new(&temp_dir);
        let store: &dyn StatePersistence<TestState, String> = &persistence;

        let state = TestState { value: 42 };
        store.save_state("test_run", &state, 5).await.unwrap();
        let loaded = store.load_state("test_run").await.unwrap();
        assert_eq!(loaded, Some((state, 5)));

        store
            .save_result("test_run", &"done".to_string())
            .await
            .unwrap();
        let result = store.load_result("test_run").await.unwrap();
        assert_eq!(result.as_deref(), Some("done"));

        let runs = store.list_runs().await.unwrap();
        assert_eq!(runs, vec!["test_run".to_string()]);

        store.delete("test_run").await.unwrap();
        assert!(store.load_state("test_run").await.unwrap().is_none());
        assert!(store.load_result("test_run").await.unwrap().is_none());

        let _ = tokio::fs::remove_dir_all(&temp_dir).await;
    }

    #[tokio::test]
    async fn test_file_persistence_missing_directory() {
        let persistence = FilePersistence::new(unique_temp_dir("missing"));
        let store: &dyn StatePersistence<TestState, String> = &persistence;

        assert!(store.list_runs().await.unwrap().is_empty());
        assert!(store.load_state("nope").await.unwrap().is_none());
        assert!(store.load_result("nope").await.unwrap().is_none());
        store.delete("nope").await.unwrap();
    }
}