rust_langgraph/checkpoint_backends/
memory.rs1use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuple};
8use crate::config::Config;
9use crate::errors::Result;
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15#[derive(Debug, Clone)]
36pub struct MemorySaver {
37 storage: Arc<RwLock<MemoryStorage>>,
38}
39
40#[derive(Debug, Default)]
41struct MemoryStorage {
42 threads: HashMap<String, Vec<CheckpointTuple>>,
44 by_id: HashMap<String, CheckpointTuple>,
46}
47
48impl MemorySaver {
49 pub fn new() -> Self {
51 Self {
52 storage: Arc::new(RwLock::new(MemoryStorage::default())),
53 }
54 }
55
56 pub async fn len(&self) -> usize {
58 let storage = self.storage.read().await;
59 storage.by_id.len()
60 }
61
62 pub async fn is_empty(&self) -> bool {
64 self.len().await == 0
65 }
66
67 pub async fn clear(&self) {
69 let mut storage = self.storage.write().await;
70 storage.threads.clear();
71 storage.by_id.clear();
72 }
73}
74
75impl Default for MemorySaver {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81#[async_trait]
82impl BaseCheckpointSaver for MemorySaver {
83 async fn get_tuple(&self, config: &Config) -> Result<Option<CheckpointTuple>> {
84 let storage = self.storage.read().await;
85
86 if let Some(checkpoint_id) = &config.checkpoint_id {
88 return Ok(storage.by_id.get(checkpoint_id).cloned());
89 }
90
91 if let Some(thread_id) = &config.thread_id {
93 if let Some(tuples) = storage.threads.get(thread_id) {
94 return Ok(tuples.last().cloned());
95 }
96 }
97
98 Ok(None)
99 }
100
101 async fn put(
102 &self,
103 checkpoint: &Checkpoint,
104 metadata: &CheckpointMetadata,
105 config: &Config,
106 ) -> Result<Config> {
107 let mut storage = self.storage.write().await;
108
109 let thread_id = checkpoint
110 .thread_id
111 .clone()
112 .or_else(|| config.thread_id.clone())
113 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
114
115 let tuple = CheckpointTuple {
116 checkpoint: checkpoint.clone(),
117 metadata: metadata.clone(),
118 config: config.clone(),
119 parent: None,
120 };
121
122 storage
124 .by_id
125 .insert(checkpoint.id.clone(), tuple.clone());
126
127 storage
129 .threads
130 .entry(thread_id.clone())
131 .or_default()
132 .push(tuple);
133
134 Ok(config.clone().with_checkpoint_id(&checkpoint.id))
136 }
137
138 async fn list(&self, config: &Config, limit: Option<usize>) -> Result<Vec<CheckpointTuple>> {
139 let storage = self.storage.read().await;
140
141 if let Some(thread_id) = &config.thread_id {
142 if let Some(tuples) = storage.threads.get(thread_id) {
143 let mut result: Vec<_> = tuples.iter().rev().cloned().collect();
144
145 if let Some(limit) = limit {
146 result.truncate(limit);
147 }
148
149 return Ok(result);
150 }
151 }
152
153 Ok(Vec::new())
154 }
155
156 async fn get(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
157 let storage = self.storage.read().await;
158 Ok(storage.by_id.get(checkpoint_id).map(|t| t.checkpoint.clone()))
159 }
160
161 async fn delete_thread(&self, thread_id: &str) -> Result<()> {
162 let mut storage = self.storage.write().await;
163
164 if let Some(tuples) = storage.threads.remove(thread_id) {
166 for tuple in tuples {
168 storage.by_id.remove(&tuple.checkpoint.id);
169 }
170 }
171
172 Ok(())
173 }
174
175 async fn prune(&self, thread_id: &str, keep: usize) -> Result<usize> {
176 let mut storage = self.storage.write().await;
177
178 if let Some(tuples) = storage.threads.get_mut(thread_id) {
179 if tuples.len() <= keep {
180 return Ok(0);
181 }
182
183 let to_remove = tuples.len() - keep;
184 let removed: Vec<_> = tuples.drain(0..to_remove).collect();
185
186 for tuple in &removed {
188 storage.by_id.remove(&tuple.checkpoint.id);
189 }
190
191 Ok(removed.len())
192 } else {
193 Ok(0)
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    #[tokio::test]
    async fn test_memory_saver_basic() {
        let saver = MemorySaver::new();
        assert_eq!(saver.len().await, 0);
        assert!(saver.is_empty().await);

        let mut checkpoint = Checkpoint::new().with_thread_id("test-thread");
        checkpoint.set_channel("count", serde_json::json!(5));

        let metadata = CheckpointMetadata {
            created_at: Utc::now(),
            step: 1,
            source: "test".to_string(),
            extra: HashMap::new(),
        };

        let config = Config::new().with_thread_id("test-thread");
        let updated_config = saver.put(&checkpoint, &metadata, &config).await.unwrap();

        assert!(updated_config.checkpoint_id.is_some());
        assert_eq!(saver.len().await, 1);
    }

    #[tokio::test]
    async fn test_memory_saver_get_tuple() {
        let saver = MemorySaver::new();

        let checkpoint = Checkpoint::new().with_thread_id("thread-1");
        let metadata = CheckpointMetadata::default();
        let config = Config::new().with_thread_id("thread-1");

        saver.put(&checkpoint, &metadata, &config).await.unwrap();

        // Lookup by thread returns the latest checkpoint.
        let tuple = saver.get_tuple(&config).await.unwrap();
        assert!(tuple.is_some());
        assert_eq!(tuple.unwrap().checkpoint.id, checkpoint.id);

        // Lookup by checkpoint id returns that exact checkpoint.
        let config_with_id = Config::new().with_checkpoint_id(&checkpoint.id);
        let tuple = saver.get_tuple(&config_with_id).await.unwrap();
        assert_eq!(tuple.map(|t| t.checkpoint.id), Some(checkpoint.id.clone()));
    }

    #[tokio::test]
    async fn test_memory_saver_get() {
        let saver = MemorySaver::new();

        let checkpoint = Checkpoint::new().with_thread_id("thread-1");
        let metadata = CheckpointMetadata::default();
        let config = Config::new().with_thread_id("thread-1");
        saver.put(&checkpoint, &metadata, &config).await.unwrap();

        let fetched = saver.get(&checkpoint.id).await.unwrap();
        assert_eq!(fetched.map(|c| c.id), Some(checkpoint.id.clone()));

        // A checkpoint that was never stored is not found.
        let never_stored = Checkpoint::new();
        assert!(saver.get(&never_stored.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_saver_list() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("thread-1");

        for i in 0..5 {
            let mut checkpoint = Checkpoint::new().with_thread_id("thread-1");
            checkpoint.set_channel("step", serde_json::json!(i));
            let metadata = CheckpointMetadata {
                step: i,
                ..Default::default()
            };
            saver.put(&checkpoint, &metadata, &config).await.unwrap();
        }

        // Newest first.
        let list = saver.list(&config, None).await.unwrap();
        assert_eq!(list.len(), 5);
        assert_eq!(list[0].metadata.step, 4);
        assert_eq!(list[4].metadata.step, 0);

        let list = saver.list(&config, Some(2)).await.unwrap();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].metadata.step, 4);
        assert_eq!(list[1].metadata.step, 3);

        assert!(saver.list(&config, Some(0)).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_memory_saver_delete_thread() {
        let saver = MemorySaver::new();

        let checkpoint1 = Checkpoint::new().with_thread_id("thread-1");
        let checkpoint2 = Checkpoint::new().with_thread_id("thread-2");
        let metadata = CheckpointMetadata::default();

        saver
            .put(
                &checkpoint1,
                &metadata,
                &Config::new().with_thread_id("thread-1"),
            )
            .await
            .unwrap();
        saver
            .put(
                &checkpoint2,
                &metadata,
                &Config::new().with_thread_id("thread-2"),
            )
            .await
            .unwrap();

        assert_eq!(saver.len().await, 2);

        saver.delete_thread("thread-1").await.unwrap();
        assert_eq!(saver.len().await, 1);

        let tuple = saver
            .get_tuple(&Config::new().with_thread_id("thread-2"))
            .await
            .unwrap();
        assert!(tuple.is_some());
    }

    #[tokio::test]
    async fn test_memory_saver_prune() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("thread-1");

        for i in 0..10 {
            let checkpoint = Checkpoint::new().with_thread_id("thread-1");
            let metadata = CheckpointMetadata {
                step: i,
                ..Default::default()
            };
            saver.put(&checkpoint, &metadata, &config).await.unwrap();
        }

        assert_eq!(saver.len().await, 10);

        let removed = saver.prune("thread-1", 3).await.unwrap();
        assert_eq!(removed, 7);
        assert_eq!(saver.len().await, 3);

        // The newest three survive.
        let list = saver.list(&config, None).await.unwrap();
        assert_eq!(list.len(), 3);
        assert_eq!(list[0].metadata.step, 9);
        assert_eq!(list[2].metadata.step, 7);
    }

    #[tokio::test]
    async fn test_memory_saver_prune_keeps_short_history() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("thread-1");
        let metadata = CheckpointMetadata::default();

        for _ in 0..2 {
            let checkpoint = Checkpoint::new().with_thread_id("thread-1");
            saver.put(&checkpoint, &metadata, &config).await.unwrap();
        }

        assert_eq!(saver.prune("thread-1", 2).await.unwrap(), 0);
        assert_eq!(saver.prune("thread-1", 5).await.unwrap(), 0);
        assert_eq!(saver.len().await, 2);
        assert_eq!(saver.list(&config, None).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_memory_saver_unknown_thread() {
        let saver = MemorySaver::new();
        let config = Config::new().with_thread_id("missing");

        assert!(saver.get_tuple(&config).await.unwrap().is_none());
        assert!(saver.get_tuple(&Config::new()).await.unwrap().is_none());
        assert!(saver.list(&config, None).await.unwrap().is_empty());
        assert_eq!(saver.prune("missing", 0).await.unwrap(), 0);
        saver.delete_thread("missing").await.unwrap();
        assert!(saver.is_empty().await);
    }

    #[tokio::test]
    async fn test_memory_saver_clear() {
        let saver = MemorySaver::new();

        let checkpoint = Checkpoint::new().with_thread_id("test");
        let metadata = CheckpointMetadata::default();
        let config = Config::new().with_thread_id("test");

        saver.put(&checkpoint, &metadata, &config).await.unwrap();
        assert!(!saver.is_empty().await);

        saver.clear().await;
        assert!(saver.is_empty().await);
    }
}