Skip to main content

rust_langgraph/checkpoint_backends/
memory.rs

1//! In-memory checkpoint backend.
2//!
3//! A simple checkpoint saver that stores checkpoints in memory.
4//! Useful for development and testing, but checkpoints are lost
5//! when the process exits.
6
7use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuple};
8use crate::config::Config;
9use crate::errors::Result;
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15/// In-memory checkpoint storage.
16///
17/// Stores checkpoints in a HashMap in memory. Fast and simple,
18/// but does not persist across process restarts.
19///
20/// # Example
21///
22/// ```rust
23/// use rust_langgraph::checkpoint_backends::memory::MemorySaver;
24/// use rust_langgraph::checkpoint::BaseCheckpointSaver;
25/// use rust_langgraph::Config;
26///
27/// #[tokio::main]
28/// async fn main() {
29///     let saver = MemorySaver::new();
30///     let config = Config::new().with_thread_id("test-123");
31///     
32///     // Use saver with graph...
33/// }
34/// ```
35#[derive(Debug, Clone)]
36pub struct MemorySaver {
37    storage: Arc<RwLock<MemoryStorage>>,
38}
39
40#[derive(Debug, Default)]
41struct MemoryStorage {
42    // Map from thread_id -> list of checkpoint tuples (oldest to newest)
43    threads: HashMap<String, Vec<CheckpointTuple>>,
44    // Map from checkpoint_id -> checkpoint tuple
45    by_id: HashMap<String, CheckpointTuple>,
46}
47
48impl MemorySaver {
49    /// Create a new in-memory checkpoint saver
50    pub fn new() -> Self {
51        Self {
52            storage: Arc::new(RwLock::new(MemoryStorage::default())),
53        }
54    }
55
56    /// Get the number of checkpoints stored
57    pub async fn len(&self) -> usize {
58        let storage = self.storage.read().await;
59        storage.by_id.len()
60    }
61
62    /// Check if storage is empty
63    pub async fn is_empty(&self) -> bool {
64        self.len().await == 0
65    }
66
67    /// Clear all stored checkpoints
68    pub async fn clear(&self) {
69        let mut storage = self.storage.write().await;
70        storage.threads.clear();
71        storage.by_id.clear();
72    }
73}
74
75impl Default for MemorySaver {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81#[async_trait]
82impl BaseCheckpointSaver for MemorySaver {
83    async fn get_tuple(&self, config: &Config) -> Result<Option<CheckpointTuple>> {
84        let storage = self.storage.read().await;
85
86        // If a specific checkpoint ID is requested
87        if let Some(checkpoint_id) = &config.checkpoint_id {
88            return Ok(storage.by_id.get(checkpoint_id).cloned());
89        }
90
91        // Otherwise, return the latest checkpoint for the thread
92        if let Some(thread_id) = &config.thread_id {
93            if let Some(tuples) = storage.threads.get(thread_id) {
94                return Ok(tuples.last().cloned());
95            }
96        }
97
98        Ok(None)
99    }
100
101    async fn put(
102        &self,
103        checkpoint: &Checkpoint,
104        metadata: &CheckpointMetadata,
105        config: &Config,
106    ) -> Result<Config> {
107        let mut storage = self.storage.write().await;
108
109        let thread_id = checkpoint
110            .thread_id
111            .clone()
112            .or_else(|| config.thread_id.clone())
113            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
114
115        let tuple = CheckpointTuple {
116            checkpoint: checkpoint.clone(),
117            metadata: metadata.clone(),
118            config: config.clone(),
119            parent: None,
120        };
121
122        // Store by ID
123        storage
124            .by_id
125            .insert(checkpoint.id.clone(), tuple.clone());
126
127        // Store in thread history
128        storage
129            .threads
130            .entry(thread_id.clone())
131            .or_default()
132            .push(tuple);
133
134        // Return config with checkpoint ID set
135        Ok(config.clone().with_checkpoint_id(&checkpoint.id))
136    }
137
138    async fn list(&self, config: &Config, limit: Option<usize>) -> Result<Vec<CheckpointTuple>> {
139        let storage = self.storage.read().await;
140
141        if let Some(thread_id) = &config.thread_id {
142            if let Some(tuples) = storage.threads.get(thread_id) {
143                let mut result: Vec<_> = tuples.iter().rev().cloned().collect();
144
145                if let Some(limit) = limit {
146                    result.truncate(limit);
147                }
148
149                return Ok(result);
150            }
151        }
152
153        Ok(Vec::new())
154    }
155
156    async fn get(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
157        let storage = self.storage.read().await;
158        Ok(storage.by_id.get(checkpoint_id).map(|t| t.checkpoint.clone()))
159    }
160
161    async fn delete_thread(&self, thread_id: &str) -> Result<()> {
162        let mut storage = self.storage.write().await;
163
164        // Remove from thread list
165        if let Some(tuples) = storage.threads.remove(thread_id) {
166            // Remove all checkpoint IDs from by_id map
167            for tuple in tuples {
168                storage.by_id.remove(&tuple.checkpoint.id);
169            }
170        }
171
172        Ok(())
173    }
174
175    async fn prune(&self, thread_id: &str, keep: usize) -> Result<usize> {
176        let mut storage = self.storage.write().await;
177
178        if let Some(tuples) = storage.threads.get_mut(thread_id) {
179            if tuples.len() <= keep {
180                return Ok(0);
181            }
182
183            let to_remove = tuples.len() - keep;
184            let removed: Vec<_> = tuples.drain(0..to_remove).collect();
185
186            // Remove from by_id map
187            for tuple in &removed {
188                storage.by_id.remove(&tuple.checkpoint.id);
189            }
190
191            Ok(removed.len())
192        } else {
193            Ok(0)
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    #[tokio::test]
    async fn test_memory_saver_basic() {
        let saver = MemorySaver::new();
        assert_eq!(saver.len().await, 0);
        assert!(saver.is_empty().await);

        let mut checkpoint = Checkpoint::new().with_thread_id("test-thread");
        checkpoint.set_channel("count", serde_json::json!(5));

        let metadata = CheckpointMetadata {
            created_at: Utc::now(),
            step: 1,
            source: "test".to_string(),
            extra: HashMap::new(),
        };

        let config = Config::new().with_thread_id("test-thread");
        let updated_config = saver.put(&checkpoint, &metadata, &config).await.unwrap();

        // The returned config must address exactly the stored checkpoint.
        assert_eq!(updated_config.checkpoint_id.as_ref(), Some(&checkpoint.id));
        assert_eq!(updated_config.thread_id.as_deref(), Some("test-thread"));
        assert_eq!(saver.len().await, 1);
        assert!(!saver.is_empty().await);
    }

    #[tokio::test]
    async fn test_memory_saver_get_tuple() {
        let saver = MemorySaver::new();

        let checkpoint = Checkpoint::new().with_thread_id("thread-1");
        let metadata = CheckpointMetadata::default();
        let config = Config::new().with_thread_id("thread-1");

        saver.put(&checkpoint, &metadata, &config).await.unwrap();

        // Get latest for thread
        let tuple = saver.get_tuple(&config).await.unwrap();
        assert!(tuple.is_some());
        assert_eq!(tuple.unwrap().checkpoint.id, checkpoint.id);

        // Get by checkpoint ID
        let config_with_id = Config::new().with_checkpoint_id(&checkpoint.id);
        let tuple = saver.get_tuple(&config_with_id).await.unwrap();
        assert_eq!(tuple.unwrap().checkpoint.id, checkpoint.id);

        // `get` returns the bare checkpoint for the same ID
        let fetched = saver.get(&checkpoint.id).await.unwrap();
        assert_eq!(fetched.unwrap().id, checkpoint.id);

        // Unknown thread / unknown ID yield nothing
        let missing_thread = Config::new().with_thread_id("no-such-thread");
        assert!(saver.get_tuple(&missing_thread).await.unwrap().is_none());
        let missing_id = Config::new().with_checkpoint_id("no-such-checkpoint");
        assert!(saver.get_tuple(&missing_id).await.unwrap().is_none());
        assert!(saver.get("no-such-checkpoint").await.unwrap().is_none());

        // A config with neither thread nor checkpoint ID matches nothing
        assert!(saver.get_tuple(&Config::new()).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_saver_list() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("thread-1");

        // Add multiple checkpoints
        for i in 0..5 {
            let mut checkpoint = Checkpoint::new().with_thread_id("thread-1");
            checkpoint.set_channel("step", serde_json::json!(i));
            let metadata = CheckpointMetadata {
                step: i,
                ..Default::default()
            };
            saver.put(&checkpoint, &metadata, &config).await.unwrap();
        }

        // List all
        let list = saver.list(&config, None).await.unwrap();
        assert_eq!(list.len(), 5);

        // Checkpoints should be in reverse order (newest first)
        assert_eq!(list[0].metadata.step, 4);
        assert_eq!(list[4].metadata.step, 0);

        // List with limit
        let list = saver.list(&config, Some(2)).await.unwrap();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].metadata.step, 4);
        assert_eq!(list[1].metadata.step, 3);

        // A zero limit and a limit beyond the history are both handled
        assert!(saver.list(&config, Some(0)).await.unwrap().is_empty());
        assert_eq!(saver.list(&config, Some(100)).await.unwrap().len(), 5);

        // Unknown thread lists nothing
        let other = Config::new().with_thread_id("no-such-thread");
        assert!(saver.list(&other, None).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_memory_saver_delete_thread() {
        let saver = MemorySaver::new();

        let checkpoint1 = Checkpoint::new().with_thread_id("thread-1");
        let checkpoint2 = Checkpoint::new().with_thread_id("thread-2");
        let metadata = CheckpointMetadata::default();

        saver
            .put(
                &checkpoint1,
                &metadata,
                &Config::new().with_thread_id("thread-1"),
            )
            .await
            .unwrap();
        saver
            .put(
                &checkpoint2,
                &metadata,
                &Config::new().with_thread_id("thread-2"),
            )
            .await
            .unwrap();

        assert_eq!(saver.len().await, 2);

        // Delete thread-1
        saver.delete_thread("thread-1").await.unwrap();
        assert_eq!(saver.len().await, 1);

        // thread-1 is gone from both the thread index and the ID index
        let gone = saver
            .get_tuple(&Config::new().with_thread_id("thread-1"))
            .await
            .unwrap();
        assert!(gone.is_none());
        assert!(saver.get(&checkpoint1.id).await.unwrap().is_none());

        // thread-2 should still exist
        let tuple = saver
            .get_tuple(&Config::new().with_thread_id("thread-2"))
            .await
            .unwrap();
        assert!(tuple.is_some());

        // Deleting an unknown thread is a harmless no-op
        saver.delete_thread("no-such-thread").await.unwrap();
        assert_eq!(saver.len().await, 1);
    }

    #[tokio::test]
    async fn test_memory_saver_prune() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("thread-1");

        // Add 10 checkpoints
        for i in 0..10 {
            let checkpoint = Checkpoint::new().with_thread_id("thread-1");
            let metadata = CheckpointMetadata {
                step: i,
                ..Default::default()
            };
            saver.put(&checkpoint, &metadata, &config).await.unwrap();
        }

        assert_eq!(saver.len().await, 10);

        // Keeping at least as many as exist removes nothing
        assert_eq!(saver.prune("thread-1", 10).await.unwrap(), 0);
        assert_eq!(saver.len().await, 10);

        // Keep only 3 most recent
        let removed = saver.prune("thread-1", 3).await.unwrap();
        assert_eq!(removed, 7);
        assert_eq!(saver.len().await, 3);

        // Verify we kept the most recent ones
        let list = saver.list(&config, None).await.unwrap();
        assert_eq!(list.len(), 3);
        assert_eq!(list[0].metadata.step, 9); // newest
        assert_eq!(list[2].metadata.step, 7); // oldest of the kept ones

        // Pruning an unknown thread removes nothing
        assert_eq!(saver.prune("no-such-thread", 0).await.unwrap(), 0);
        assert_eq!(saver.len().await, 3);
    }

    #[tokio::test]
    async fn test_memory_saver_clear() {
        let saver = MemorySaver::new();

        let checkpoint = Checkpoint::new().with_thread_id("test");
        let metadata = CheckpointMetadata::default();
        let config = Config::new().with_thread_id("test");

        saver.put(&checkpoint, &metadata, &config).await.unwrap();
        assert!(!saver.is_empty().await);

        saver.clear().await;
        assert!(saver.is_empty().await);
        assert!(saver.get_tuple(&config).await.unwrap().is_none());
        assert!(saver.list(&config, None).await.unwrap().is_empty());
    }
}