rust_langgraph/checkpoint/
mod.rs1use crate::config::Config;
7use crate::errors::{Error, Result};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Checkpoint {
20 pub v: i32,
22
23 pub id: String,
25
26 pub ts: String,
28
29 pub channel_values: HashMap<String, serde_json::Value>,
31
32 pub channel_versions: HashMap<String, i32>,
34
35 pub versions_seen: HashMap<String, HashMap<String, i32>>,
37
38 pub thread_id: Option<String>,
40
41 pub parent_id: Option<String>,
43}
44
45impl Checkpoint {
46 pub fn new() -> Self {
48 Self {
49 v: 1,
50 id: Uuid::new_v4().to_string(),
51 ts: Utc::now().to_rfc3339(),
52 channel_values: HashMap::new(),
53 channel_versions: HashMap::new(),
54 versions_seen: HashMap::new(),
55 thread_id: None,
56 parent_id: None,
57 }
58 }
59
60 pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
62 self.thread_id = Some(thread_id.into());
63 self
64 }
65
66 pub fn set_channel(&mut self, name: impl Into<String>, value: serde_json::Value) {
68 let name = name.into();
69 let version = self.channel_versions.get(&name).copied().unwrap_or(0) + 1;
70 self.channel_values.insert(name.clone(), value);
71 self.channel_versions.insert(name, version);
72 }
73
74 pub fn get_channel(&self, name: &str) -> Option<&serde_json::Value> {
76 self.channel_values.get(name)
77 }
78}
79
80impl Default for Checkpoint {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct CheckpointTuple {
89 pub checkpoint: Checkpoint,
91
92 pub metadata: CheckpointMetadata,
94
95 pub config: Config,
97
98 pub parent: Option<Box<CheckpointTuple>>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CheckpointMetadata {
105 pub created_at: DateTime<Utc>,
107
108 pub step: usize,
110
111 pub source: String,
113
114 pub extra: HashMap<String, serde_json::Value>,
116}
117
118impl Default for CheckpointMetadata {
119 fn default() -> Self {
120 Self {
121 created_at: Utc::now(),
122 step: 0,
123 source: "unknown".to_string(),
124 extra: HashMap::new(),
125 }
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct StateSnapshot<S> {
134 pub state: S,
136
137 pub checkpoint: Checkpoint,
139
140 pub metadata: CheckpointMetadata,
142
143 pub config: Config,
145}
146
147#[async_trait]
152pub trait BaseCheckpointSaver: Send + Sync {
153 async fn get_tuple(&self, config: &Config) -> Result<Option<CheckpointTuple>>;
158
159 async fn put(
163 &self,
164 checkpoint: &Checkpoint,
165 metadata: &CheckpointMetadata,
166 config: &Config,
167 ) -> Result<Config>;
168
169 async fn list(&self, config: &Config, limit: Option<usize>) -> Result<Vec<CheckpointTuple>>;
173
174 async fn get(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
176 let config = Config::new().with_checkpoint_id(checkpoint_id);
178 Ok(self.get_tuple(&config).await?.map(|t| t.checkpoint))
179 }
180
181 async fn delete_thread(&self, thread_id: &str) -> Result<()> {
183 Err(Error::checkpoint(format!(
185 "delete_thread not implemented for thread {}",
186 thread_id
187 )))
188 }
189
190 async fn prune(&self, thread_id: &str, keep: usize) -> Result<usize> {
192 let _ = (thread_id, keep);
194 Err(Error::checkpoint("prune not implemented"))
195 }
196}
197
198pub type CheckpointSaverBox = Box<dyn BaseCheckpointSaver>;
200
201pub type CheckpointSaverArc = std::sync::Arc<dyn BaseCheckpointSaver>;
203
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A new checkpoint has version tag 1, a non-empty id, and no channel values.
    #[test]
    fn test_checkpoint_creation() {
        let fresh = Checkpoint::new();

        assert_eq!(fresh.v, 1);
        assert!(!fresh.id.is_empty());
        assert!(fresh.channel_values.is_empty());
    }

    /// The builder stores the given thread id.
    #[test]
    fn test_checkpoint_with_thread_id() {
        let tagged = Checkpoint::new().with_thread_id("thread-123");

        assert_eq!(tagged.thread_id.as_deref(), Some("thread-123"));
    }

    /// Writing a channel stores the value and bumps the version on every write.
    #[test]
    fn test_checkpoint_set_get_channel() {
        let mut cp = Checkpoint::new();

        // First write: value is readable and the version starts at 1.
        cp.set_channel("my_channel", json!({"value": 42}));
        assert_eq!(cp.get_channel("my_channel"), Some(&json!({"value": 42})));
        assert_eq!(cp.channel_versions.get("my_channel"), Some(&1));

        // Second write to the same channel moves the version to 2.
        cp.set_channel("my_channel", json!({"value": 43}));
        assert_eq!(cp.channel_versions.get("my_channel"), Some(&2));
    }

    /// A JSON round trip keeps the thread id and the channel value.
    #[test]
    fn test_checkpoint_serialization() {
        let mut original = Checkpoint::new().with_thread_id("test");
        original.set_channel("count", json!(5));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Checkpoint = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.thread_id, original.thread_id);
        assert_eq!(decoded.get_channel("count"), original.get_channel("count"));
    }

    /// Struct-update syntax over the default keeps the explicitly set fields.
    #[test]
    fn test_checkpoint_metadata() {
        let meta = CheckpointMetadata {
            source: String::from("test"),
            step: 5,
            ..CheckpointMetadata::default()
        };

        assert_eq!(meta.step, 5);
        assert_eq!(meta.source, "test");
    }
}