// swink_agent/checkpoint/store.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4
5use super::Checkpoint;
6
/// Heap-allocated, type-erased future produced by every [`CheckpointStore`] method.
///
/// The future is `Send`, may borrow data for the lifetime `'a`, and resolves to
/// an [`io::Result<T>`](io::Result).
pub type CheckpointFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
9
10/// Async trait for persisting and loading agent checkpoints.
11///
12/// Implementations can back onto any storage: filesystem, database, cloud, etc.
13pub trait CheckpointStore: Send + Sync {
14    /// Save a checkpoint. Overwrites any existing checkpoint with the same ID.
15    fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()>;
16
17    /// Load a checkpoint by ID.
18    fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>>;
19
20    /// List all checkpoint IDs, most recent first.
21    fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>>;
22
23    /// Delete a checkpoint by ID.
24    fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()>;
25}
26
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Mutex, MutexGuard};

    use super::*;

    /// Mutable state of [`InMemoryCheckpointStore`], guarded by a single mutex.
    #[derive(Default)]
    struct State {
        /// Sequence number handed to the next save. It only ever increases, so
        /// a larger number means a more recent save.
        next_seq: u64,
        /// Checkpoint ID -> (sequence number of its latest save, serialized JSON).
        entries: HashMap<String, (u64, String)>,
    }

    /// Test double that keeps checkpoints as JSON strings in memory.
    struct InMemoryCheckpointStore {
        state: Mutex<State>,
    }

    impl InMemoryCheckpointStore {
        fn new() -> Self {
            Self {
                state: Mutex::new(State::default()),
            }
        }

        /// Lock the store, converting mutex poisoning into an `io::Error`.
        fn lock_state(&self) -> io::Result<MutexGuard<'_, State>> {
            self.state
                .lock()
                .map_err(|error| io::Error::other(error.to_string()))
        }
    }

    impl CheckpointStore for InMemoryCheckpointStore {
        fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()> {
            // Serialize before building the future so it only captures owned
            // strings. A serialization failure is reported through the returned
            // `io::Result` rather than panicking.
            let serialized = serde_json::to_string(&checkpoint).map_err(io::Error::other);
            let id = checkpoint.id;
            Box::pin(async move {
                let json = serialized?;
                let mut state = self.lock_state()?;
                let seq = state.next_seq;
                state.next_seq += 1;
                // Re-saving an existing ID overwrites it and refreshes its recency.
                state.entries.insert(id, (seq, json));
                Ok(())
            })
        }

        fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>> {
            let id = id.to_string();
            Box::pin(async move {
                self.lock_state()?
                    .entries
                    .get(&id)
                    .map(|(_, json)| serde_json::from_str(json).map_err(io::Error::other))
                    .transpose()
            })
        }

        fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>> {
            Box::pin(async move {
                let state = self.lock_state()?;
                let mut by_recency: Vec<(u64, String)> = state
                    .entries
                    .iter()
                    .map(|(id, (seq, _))| (*seq, id.clone()))
                    .collect();
                drop(state);
                // The trait promises "most recent first": highest sequence number first.
                by_recency.sort_unstable_by_key(|(seq, _)| std::cmp::Reverse(*seq));
                Ok(by_recency.into_iter().map(|(_, id)| id).collect())
            })
        }

        fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()> {
            let id = id.to_string();
            Box::pin(async move {
                self.lock_state()?.entries.remove(&id);
                Ok(())
            })
        }
    }

    #[tokio::test]
    async fn in_memory_checkpoint_store_roundtrip() {
        let store = InMemoryCheckpointStore::new();
        let checkpoint =
            Checkpoint::new("cp-store-test", "prompt", "provider", "model", &[]).with_turn_count(2);

        store.save_checkpoint(checkpoint).await.unwrap();

        let ids = store.list_checkpoints().await.unwrap();
        assert_eq!(ids, vec!["cp-store-test".to_string()]);

        let loaded = store
            .load_checkpoint("cp-store-test")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.id, "cp-store-test");
        assert_eq!(loaded.turn_count, 2);

        let missing = store.load_checkpoint("nope").await.unwrap();
        assert!(missing.is_none());

        store.delete_checkpoint("cp-store-test").await.unwrap();
        assert!(store.list_checkpoints().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_checkpoints_returns_most_recent_first() {
        let store = InMemoryCheckpointStore::new();
        for id in ["cp-a", "cp-b", "cp-c"] {
            store
                .save_checkpoint(Checkpoint::new(id, "prompt", "provider", "model", &[]))
                .await
                .unwrap();
        }
        assert_eq!(
            store.list_checkpoints().await.unwrap(),
            vec!["cp-c", "cp-b", "cp-a"]
        );

        // Overwriting an existing checkpoint makes it the most recent one.
        store
            .save_checkpoint(Checkpoint::new("cp-a", "prompt", "provider", "model", &[]))
            .await
            .unwrap();
        assert_eq!(
            store.list_checkpoints().await.unwrap(),
            vec!["cp-a", "cp-c", "cp-b"]
        );
    }
}