Skip to main content

rust_langgraph/checkpoint/
mod.rs

//! Checkpoint system for graph state persistence.
//!
//! Checkpoints allow graphs to save and restore execution state,
//! enabling features such as pause/resume, time travel, and crash recovery.
5
6use crate::config::Config;
7use crate::errors::{Error, Result};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use uuid::Uuid;
13
14/// A checkpoint representing the state of a graph at a point in time.
15///
16/// Checkpoints are compatible with the Python LangGraph wire format,
17/// allowing interoperability between Rust and Python implementations.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Checkpoint {
20    /// Checkpoint format version
21    pub v: i32,
22
23    /// Unique checkpoint identifier
24    pub id: String,
25
26    /// Timestamp when checkpoint was created
27    pub ts: String,
28
29    /// The values of all channels at this checkpoint
30    pub channel_values: HashMap<String, serde_json::Value>,
31
32    /// Version numbers for each channel
33    pub channel_versions: HashMap<String, i32>,
34
35    /// Versions seen by each channel (for tracking updates)
36    pub versions_seen: HashMap<String, HashMap<String, i32>>,
37
38    /// Thread ID this checkpoint belongs to
39    pub thread_id: Option<String>,
40
41    /// Parent checkpoint ID (for nested graphs/subgraphs)
42    pub parent_id: Option<String>,
43}
44
45impl Checkpoint {
46    /// Create a new empty checkpoint
47    pub fn new() -> Self {
48        Self {
49            v: 1,
50            id: Uuid::new_v4().to_string(),
51            ts: Utc::now().to_rfc3339(),
52            channel_values: HashMap::new(),
53            channel_versions: HashMap::new(),
54            versions_seen: HashMap::new(),
55            thread_id: None,
56            parent_id: None,
57        }
58    }
59
60    /// Create a checkpoint with a specific thread ID
61    pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
62        self.thread_id = Some(thread_id.into());
63        self
64    }
65
66    /// Set a channel value
67    pub fn set_channel(&mut self, name: impl Into<String>, value: serde_json::Value) {
68        let name = name.into();
69        let version = self.channel_versions.get(&name).copied().unwrap_or(0) + 1;
70        self.channel_values.insert(name.clone(), value);
71        self.channel_versions.insert(name, version);
72    }
73
74    /// Get a channel value
75    pub fn get_channel(&self, name: &str) -> Option<&serde_json::Value> {
76        self.channel_values.get(name)
77    }
78}
79
80impl Default for Checkpoint {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86/// A checkpoint along with metadata about when and where it was saved.
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct CheckpointTuple {
89    /// The checkpoint itself
90    pub checkpoint: Checkpoint,
91
92    /// Metadata about the checkpoint
93    pub metadata: CheckpointMetadata,
94
95    /// The configuration used for this checkpoint
96    pub config: Config,
97
98    /// Parent checkpoint tuple (for nested graphs)
99    pub parent: Option<Box<CheckpointTuple>>,
100}
101
102/// Metadata about a checkpoint.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CheckpointMetadata {
105    /// When the checkpoint was created
106    pub created_at: DateTime<Utc>,
107
108    /// Step number when checkpoint was created
109    pub step: usize,
110
111    /// Source of the checkpoint (e.g., "pregel", "user")
112    pub source: String,
113
114    /// Additional custom metadata
115    pub extra: HashMap<String, serde_json::Value>,
116}
117
118impl Default for CheckpointMetadata {
119    fn default() -> Self {
120        Self {
121            created_at: Utc::now(),
122            step: 0,
123            source: "unknown".to_string(),
124            extra: HashMap::new(),
125        }
126    }
127}
128
129/// A snapshot of graph state at a specific checkpoint.
130///
131/// This includes both the checkpoint data and the deserialized state.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct StateSnapshot<S> {
134    /// The state at this snapshot
135    pub state: S,
136
137    /// The checkpoint
138    pub checkpoint: Checkpoint,
139
140    /// Metadata
141    pub metadata: CheckpointMetadata,
142
143    /// Configuration
144    pub config: Config,
145}
146
147/// Trait for checkpoint storage backends.
148///
149/// Implementations of this trait provide persistent storage for checkpoints,
150/// enabling save/resume functionality across process restarts.
151#[async_trait]
152pub trait BaseCheckpointSaver: Send + Sync {
153    /// Get a checkpoint tuple for the given configuration.
154    ///
155    /// If `config.checkpoint_id` is set, returns that specific checkpoint.
156    /// Otherwise, returns the latest checkpoint for the thread.
157    async fn get_tuple(&self, config: &Config) -> Result<Option<CheckpointTuple>>;
158
159    /// Save a checkpoint.
160    ///
161    /// Returns an updated Config with the checkpoint ID set.
162    async fn put(
163        &self,
164        checkpoint: &Checkpoint,
165        metadata: &CheckpointMetadata,
166        config: &Config,
167    ) -> Result<Config>;
168
169    /// List checkpoints for a given configuration.
170    ///
171    /// Returns checkpoints in reverse chronological order.
172    async fn list(&self, config: &Config, limit: Option<usize>) -> Result<Vec<CheckpointTuple>>;
173
174    /// Get a specific checkpoint by ID
175    async fn get(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
176        // Default implementation uses get_tuple
177        let config = Config::new().with_checkpoint_id(checkpoint_id);
178        Ok(self.get_tuple(&config).await?.map(|t| t.checkpoint))
179    }
180
181    /// Delete all checkpoints for a thread
182    async fn delete_thread(&self, thread_id: &str) -> Result<()> {
183        // Default implementation returns not implemented error
184        Err(Error::checkpoint(format!(
185            "delete_thread not implemented for thread {}",
186            thread_id
187        )))
188    }
189
190    /// Prune old checkpoints, keeping only the most recent ones
191    async fn prune(&self, thread_id: &str, keep: usize) -> Result<usize> {
192        // Default implementation returns not implemented error
193        let _ = (thread_id, keep);
194        Err(Error::checkpoint("prune not implemented"))
195    }
196}
197
198/// Type alias for boxed checkpoint savers
199pub type CheckpointSaverBox = Box<dyn BaseCheckpointSaver>;
200
201/// Type alias for Arc'd checkpoint savers
202pub type CheckpointSaverArc = std::sync::Arc<dyn BaseCheckpointSaver>;
203
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh checkpoint has format version 1, a non-empty id and no channels.
    #[test]
    fn test_checkpoint_creation() {
        let fresh = Checkpoint::new();

        assert_eq!(fresh.v, 1);
        assert!(!fresh.id.is_empty());
        assert!(fresh.channel_values.is_empty());
    }

    /// The builder stores the thread id it is given.
    #[test]
    fn test_checkpoint_with_thread_id() {
        let tagged = Checkpoint::new().with_thread_id("thread-123");

        assert_eq!(tagged.thread_id, Some("thread-123".to_string()));
    }

    /// Writing a channel stores the value and bumps its version each time.
    #[test]
    fn test_checkpoint_set_get_channel() {
        let mut cp = Checkpoint::new();
        let first_value = serde_json::json!({"value": 42});

        cp.set_channel("my_channel", first_value.clone());
        assert_eq!(cp.get_channel("my_channel"), Some(&first_value));
        assert_eq!(cp.channel_versions.get("my_channel").copied(), Some(1));

        // A second write to the same channel advances the version again.
        cp.set_channel("my_channel", serde_json::json!({"value": 43}));
        assert_eq!(cp.channel_versions.get("my_channel").copied(), Some(2));
    }

    /// A JSON round trip keeps the thread id and the channel value.
    #[test]
    fn test_checkpoint_serialization() {
        let original = {
            let mut cp = Checkpoint::new().with_thread_id("test");
            cp.set_channel("count", serde_json::json!(5));
            cp
        };

        let encoded = serde_json::to_string(&original).expect("checkpoint should serialize");
        let decoded: Checkpoint =
            serde_json::from_str(&encoded).expect("checkpoint should deserialize");

        assert_eq!(decoded.thread_id, original.thread_id);
        assert_eq!(decoded.get_channel("count"), original.get_channel("count"));
    }

    /// Explicitly set metadata fields win over the defaults.
    #[test]
    fn test_checkpoint_metadata() {
        let meta = CheckpointMetadata {
            step: 5,
            source: String::from("test"),
            ..CheckpointMetadata::default()
        };

        assert_eq!(meta.step, 5);
        assert_eq!(meta.source, "test");
    }
}