// swink_agent/checkpoint/store.rs
use std::future::Future;
use std::io;
use std::pin::Pin;

use super::Checkpoint;
6
/// Boxed, pinned, `Send` future returned by every [`CheckpointStore`] method.
///
/// The future resolves to an [`io::Result<T>`](io::Result). It may borrow data
/// for the lifetime `'a`; in the trait methods that lifetime is elided and is
/// therefore tied to the `&self` receiver.
pub type CheckpointFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
9
10pub trait CheckpointStore: Send + Sync {
14 fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()>;
16
17 fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>>;
19
20 fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>>;
22
23 fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()>;
25}
26
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Mutex, MutexGuard};

    use super::*;

    /// Minimal [`CheckpointStore`] used by the tests: checkpoints are kept as
    /// JSON strings in a mutex-guarded map keyed by checkpoint id.
    struct InMemoryCheckpointStore {
        data: Mutex<HashMap<String, String>>,
    }

    impl InMemoryCheckpointStore {
        /// Creates an empty store.
        fn new() -> Self {
            Self {
                data: Mutex::new(HashMap::new()),
            }
        }

        /// Locks the map, reporting a poisoned mutex as an [`io::Error`]
        /// instead of panicking.
        fn lock_data(&self) -> io::Result<MutexGuard<'_, HashMap<String, String>>> {
            self.data
                .lock()
                .map_err(|error| io::Error::other(error.to_string()))
        }
    }

    impl CheckpointStore for InMemoryCheckpointStore {
        fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()> {
            // Serialize eagerly (so the future does not have to own the whole
            // checkpoint), but carry a failure into the future as an
            // `io::Error` instead of panicking, mirroring how
            // `load_checkpoint` reports deserialization errors.
            let serialized = serde_json::to_string(&checkpoint).map_err(io::Error::other);
            let id = checkpoint.id;
            Box::pin(async move {
                let json = serialized?;
                // `HashMap::insert` replaces any entry already stored under `id`.
                self.lock_data()?.insert(id, json);
                Ok(())
            })
        }

        fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>> {
            // The returned future may only borrow `self`, so own the id.
            let id = id.to_string();
            Box::pin(async move {
                self.lock_data()?
                    .get(&id)
                    .map(|json| serde_json::from_str(json).map_err(io::Error::other))
                    // Option<Result<_>> -> Result<Option<_>>: a missing id is `Ok(None)`.
                    .transpose()
            })
        }

        fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>> {
            Box::pin(async move { Ok(self.lock_data()?.keys().cloned().collect()) })
        }

        fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()> {
            // The returned future may only borrow `self`, so own the id.
            let id = id.to_string();
            Box::pin(async move {
                // Removing an id that is not present is a successful no-op.
                self.lock_data()?.remove(&id);
                Ok(())
            })
        }
    }

    /// Saves, lists, loads and deletes one checkpoint through the trait.
    #[tokio::test]
    async fn in_memory_checkpoint_store_roundtrip() {
        let store = InMemoryCheckpointStore::new();
        let checkpoint =
            Checkpoint::new("cp-store-test", "prompt", "provider", "model", &[]).with_turn_count(2);

        store.save_checkpoint(checkpoint).await.unwrap();

        let ids = store.list_checkpoints().await.unwrap();
        assert_eq!(ids, vec!["cp-store-test".to_string()]);

        let loaded = store
            .load_checkpoint("cp-store-test")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.id, "cp-store-test");
        assert_eq!(loaded.turn_count, 2);

        let missing = store.load_checkpoint("nope").await.unwrap();
        assert!(missing.is_none());

        store.delete_checkpoint("cp-store-test").await.unwrap();
        assert!(store.list_checkpoints().await.unwrap().is_empty());
    }
}