Skip to main content

synwire_checkpoint/
memory.rs

1//! In-memory checkpoint saver implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7
8use crate::base::BaseCheckpointSaver;
9use crate::types::{
10    Checkpoint, CheckpointConfig, CheckpointError, CheckpointMetadata, CheckpointTuple,
11};
12
13/// An in-memory checkpoint saver backed by a `RwLock<HashMap>`.
14///
15/// Stores checkpoints in memory keyed by thread ID. Suitable for
16/// testing and short-lived processes. Data is lost when the process exits.
17#[derive(Debug, Clone, Default)]
18pub struct InMemoryCheckpointSaver {
19    storage: Arc<RwLock<HashMap<String, Vec<CheckpointTuple>>>>,
20}
21
22impl InMemoryCheckpointSaver {
23    /// Create a new, empty in-memory checkpoint saver.
24    pub fn new() -> Self {
25        Self::default()
26    }
27}
28
29#[allow(clippy::significant_drop_tightening)]
30impl BaseCheckpointSaver for InMemoryCheckpointSaver {
31    fn get_tuple<'a>(
32        &'a self,
33        config: &'a CheckpointConfig,
34    ) -> synwire_core::BoxFuture<'a, Result<Option<CheckpointTuple>, CheckpointError>> {
35        Box::pin(async move {
36            let storage = self.storage.read().await;
37            let Some(tuples) = storage.get(&config.thread_id) else {
38                return Ok(None);
39            };
40            Ok(config.checkpoint_id.as_ref().map_or_else(
41                || tuples.last().cloned(),
42                |checkpoint_id| {
43                    tuples
44                        .iter()
45                        .find(|t| t.checkpoint.id == *checkpoint_id)
46                        .cloned()
47                },
48            ))
49        })
50    }
51
52    fn list<'a>(
53        &'a self,
54        config: &'a CheckpointConfig,
55        limit: Option<usize>,
56    ) -> synwire_core::BoxFuture<'a, Result<Vec<CheckpointTuple>, CheckpointError>> {
57        Box::pin(async move {
58            let storage = self.storage.read().await;
59            let Some(tuples) = storage.get(&config.thread_id) else {
60                return Ok(Vec::new());
61            };
62            let mut result: Vec<CheckpointTuple> = tuples.iter().rev().cloned().collect();
63            if let Some(limit) = limit {
64                result.truncate(limit);
65            }
66            Ok(result)
67        })
68    }
69
70    fn put<'a>(
71        &'a self,
72        config: &'a CheckpointConfig,
73        checkpoint: Checkpoint,
74        metadata: CheckpointMetadata,
75    ) -> synwire_core::BoxFuture<'a, Result<CheckpointConfig, CheckpointError>> {
76        Box::pin(async move {
77            let new_config = CheckpointConfig {
78                thread_id: config.thread_id.clone(),
79                checkpoint_id: Some(checkpoint.id.clone()),
80            };
81
82            let mut storage = self.storage.write().await;
83            let tuples = storage.entry(config.thread_id.clone()).or_default();
84            let parent_config = tuples.last().map(|t| t.config.clone());
85
86            tuples.push(CheckpointTuple {
87                config: new_config.clone(),
88                checkpoint,
89                metadata,
90                parent_config,
91            });
92
93            Ok(new_config)
94        })
95    }
96}
97
#[cfg(test)]
// Unwrapping is the intended failure mode in tests: a panic fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::types::CheckpointSource;
    use serde_json::json;

    /// Builds checkpoint `id` with an empty `messages` channel, paired with
    /// loop-sourced metadata for `step`.
    fn make_checkpoint(id: &str, step: i64) -> (Checkpoint, CheckpointMetadata) {
        let mut cp = Checkpoint::new(id.to_owned());
        let _prev = cp.channel_values.insert("messages".into(), json!([]));
        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
        };
        (cp, metadata)
    }

    /// Config addressing `thread_id`, optionally pinned to `checkpoint_id`.
    fn thread_config(thread_id: &str, checkpoint_id: Option<&str>) -> CheckpointConfig {
        CheckpointConfig {
            thread_id: thread_id.to_owned(),
            checkpoint_id: checkpoint_id.map(str::to_owned),
        }
    }

    /// T216: `InMemoryCheckpointSaver` put and get round-trip.
    #[tokio::test]
    async fn put_and_get_round_trip() {
        let saver = InMemoryCheckpointSaver::new();
        let config = thread_config("thread-1", None);
        let (cp, meta) = make_checkpoint("cp-1", 0);
        let result_config = saver.put(&config, cp, meta).await.unwrap();
        assert_eq!(result_config.thread_id, "thread-1");
        assert_eq!(result_config.checkpoint_id.as_deref(), Some("cp-1"));

        // Get by thread_id (latest)
        let tuple = saver.get_tuple(&config).await.unwrap().unwrap();
        assert_eq!(tuple.checkpoint.id, "cp-1");
        assert_eq!(tuple.checkpoint.channel_values["messages"], json!([]));
        // The first checkpoint of a thread has no parent.
        assert!(tuple.parent_config.is_none());

        // Get by specific checkpoint_id
        let specific = thread_config("thread-1", Some("cp-1"));
        let tuple = saver.get_tuple(&specific).await.unwrap().unwrap();
        assert_eq!(tuple.checkpoint.id, "cp-1");

        // Unknown checkpoint id on a known thread
        let unknown_id = thread_config("thread-1", Some("no-such-checkpoint"));
        assert!(saver.get_tuple(&unknown_id).await.unwrap().is_none());

        // Unknown thread
        let missing = thread_config("no-such-thread", None);
        assert!(saver.get_tuple(&missing).await.unwrap().is_none());
    }

    /// T217: list returns in reverse chronological order.
    #[tokio::test]
    async fn list_returns_in_order() {
        let saver = InMemoryCheckpointSaver::new();
        let config = thread_config("thread-1", None);

        for i in 0..5 {
            let (cp, meta) = make_checkpoint(&format!("cp-{i}"), i64::from(i));
            let _cfg = saver.put(&config, cp, meta).await.unwrap();
        }

        // List all -- newest first
        let all = saver.list(&config, None).await.unwrap();
        let ids: Vec<&str> = all.iter().map(|t| t.checkpoint.id.as_str()).collect();
        assert_eq!(ids, ["cp-4", "cp-3", "cp-2", "cp-1", "cp-0"]);

        // List with limit
        let limited = saver.list(&config, Some(2)).await.unwrap();
        assert_eq!(limited.len(), 2);
        assert_eq!(limited[0].checkpoint.id, "cp-4");
        assert_eq!(limited[1].checkpoint.id, "cp-3");

        // A zero limit yields nothing; a limit past the history yields everything.
        assert!(saver.list(&config, Some(0)).await.unwrap().is_empty());
        assert_eq!(saver.list(&config, Some(10)).await.unwrap().len(), 5);

        // Unknown thread lists nothing
        let missing = thread_config("no-such-thread", None);
        assert!(saver.list(&missing, None).await.unwrap().is_empty());

        // Each checkpoint's parent is the one stored just before it...
        for pair in all.windows(2) {
            let parent = pair[0]
                .parent_config
                .as_ref()
                .and_then(|p| p.checkpoint_id.as_deref());
            assert_eq!(parent, Some(pair[1].checkpoint.id.as_str()));
        }
        // ...and the oldest one has none.
        assert!(all[4].parent_config.is_none());
    }

    /// T223: `format_version` defaults to "1.0".
    // Purely synchronous, so no async runtime is needed.
    #[test]
    fn format_version_default() {
        let cp = Checkpoint::new("test".into());
        assert_eq!(cp.format_version, "1.0");
    }
}