// sh_layer2/checkpoint_system/writer.rs

use async_trait::async_trait;
use chrono::Utc;
use std::io;
use std::path::{Path, PathBuf};

use crate::types::{CheckpointId, CheckpointMeta, Layer2Result, SessionId};

use super::{AtomicFileWriter, CheckpointData, CheckpointSystemTrait, ChecksumUtils};

13pub struct CheckpointWriter {
15 storage_path: PathBuf,
16 max_backups: usize,
17 atomic_writer: AtomicFileWriter,
18}
19
20impl CheckpointWriter {
21 pub fn new(storage_path: impl AsRef<Path>) -> Self {
23 Self {
24 storage_path: storage_path.as_ref().to_path_buf(),
25 max_backups: 3,
26 atomic_writer: AtomicFileWriter::new(),
27 }
28 }
29
30 pub fn with_max_backups(mut self, max: usize) -> Self {
32 self.max_backups = max;
33 self
34 }
35
36 fn session_dir(&self, session_id: &SessionId) -> PathBuf {
38 self.storage_path
39 .join(session_id.to_string())
40 .join("checkpoints")
41 }
42
43 fn backup_checkpoint(&self, filepath: &Path) {
45 if !filepath.exists() {
46 return;
47 }
48
49 let ext = filepath.extension().and_then(|e| e.to_str()).unwrap_or("");
50
51 let backup_path = filepath.with_extension(format!("{}.backup", ext));
52 let _ = std::fs::copy(filepath, backup_path);
53 }
54
55 fn prune_backups(&self, session_dir: &Path) {
57 let mut backups: Vec<_> = std::fs::read_dir(session_dir)
58 .ok()
59 .into_iter()
60 .flatten()
61 .filter_map(|e| e.ok())
62 .filter(|e| {
63 e.path()
64 .extension()
65 .map(|ext| ext == "backup")
66 .unwrap_or(false)
67 })
68 .collect();
69
70 backups.sort_by(|a, b| {
72 let a_time = a
73 .metadata()
74 .and_then(|m| m.modified())
75 .unwrap_or(std::time::UNIX_EPOCH);
76 let b_time = b
77 .metadata()
78 .and_then(|m| m.modified())
79 .unwrap_or(std::time::UNIX_EPOCH);
80 b_time.cmp(&a_time)
81 });
82
83 for old_backup in backups.into_iter().skip(self.max_backups) {
85 let _ = std::fs::remove_file(old_backup.path());
86 }
87 }
88
89 fn update_latest(&self, session_dir: &Path, checkpoint_path: &Path) -> Layer2Result<()> {
91 let latest_path = session_dir.join("latest.json");
92
93 #[cfg(windows)]
94 {
95 std::fs::copy(checkpoint_path, &latest_path)?;
97 }
98
99 #[cfg(not(windows))]
100 {
101 let temp_link = session_dir.join(format!(".tmp_latest_{}", uuid::Uuid::new_v4()));
103 std::os::unix::fs::symlink(checkpoint_path.file_name().unwrap(), &temp_link)?;
104 std::fs::rename(&temp_link, &latest_path)?;
105 }
106
107 Ok(())
108 }
109}
110
111#[async_trait]
112impl CheckpointSystemTrait for CheckpointWriter {
113 async fn save(&self, data: &CheckpointData) -> Layer2Result<CheckpointId> {
114 let session_dir = self.session_dir(&data.session_id);
115 std::fs::create_dir_all(&session_dir)?;
116
117 let timestamp = Utc::now().format("%Y%m%d_%H%M%S");
119 let checkpoint_id = data.checkpoint_id.clone();
120 let filename = format!("cp_{}_{}.json", timestamp, checkpoint_id);
121 let filepath = session_dir.join(&filename);
122
123 let mut json_data = serde_json::to_value(data)?;
125 json_data = ChecksumUtils::add_checksum(json_data);
126 let json_content = serde_json::to_string_pretty(&json_data)?;
127
128 let latest_path = session_dir.join("latest.json");
130 self.backup_checkpoint(&latest_path);
131
132 self.atomic_writer.write_atomic(&filepath, &json_content)?;
134
135 let _ = self.update_latest(&session_dir, &filepath);
137
138 self.prune_backups(&session_dir);
140
141 Ok(checkpoint_id)
142 }
143
144 async fn load(
145 &self,
146 session_id: &SessionId,
147 checkpoint_id: Option<&CheckpointId>,
148 ) -> Layer2Result<Option<CheckpointData>> {
149 let session_dir = self.session_dir(session_id);
150
151 if !session_dir.exists() {
152 return Ok(None);
153 }
154
155 let filepath = if let Some(id) = checkpoint_id {
157 let pattern = format!("cp_*_{}.json", id);
159 let matches: Vec<_> =
160 glob::glob(session_dir.join(&pattern).to_string_lossy().as_ref())?
161 .filter_map(|e| e.ok())
162 .collect();
163
164 if matches.is_empty() {
165 return Ok(None);
166 }
167 matches[0].clone()
168 } else {
169 let latest_path = session_dir.join("latest.json");
171 if latest_path.exists() {
172 latest_path
173 } else {
174 let mut checkpoints: Vec<_> = std::fs::read_dir(&session_dir)?
176 .filter_map(|e| e.ok())
177 .filter(|e| {
178 e.file_name().to_string_lossy().starts_with("cp_")
179 && e.path()
180 .extension()
181 .map(|ext| ext == "json")
182 .unwrap_or(false)
183 })
184 .collect();
185
186 if checkpoints.is_empty() {
187 return Ok(None);
188 }
189
190 checkpoints.sort_by(|a, b| {
191 let a_time = a
192 .metadata()
193 .and_then(|m| m.modified())
194 .unwrap_or(std::time::UNIX_EPOCH);
195 let b_time = b
196 .metadata()
197 .and_then(|m| m.modified())
198 .unwrap_or(std::time::UNIX_EPOCH);
199 b_time.cmp(&a_time)
200 });
201
202 checkpoints[0].path()
203 }
204 };
205
206 if !filepath.exists() {
207 return Ok(None);
208 }
209
210 let content = std::fs::read_to_string(&filepath)?;
212 let data: serde_json::Value = serde_json::from_str(&content)?;
213
214 let (valid, _) = ChecksumUtils::verify_checksum(&data);
215 if !valid {
216 return Err(anyhow::anyhow!("Checkpoint checksum verification failed"));
217 }
218
219 let checkpoint: CheckpointData = serde_json::from_value(data)?;
220 Ok(Some(checkpoint))
221 }
222
223 async fn list(&self, session_id: &SessionId) -> Layer2Result<Vec<CheckpointMeta>> {
224 let session_dir = self.session_dir(session_id);
225
226 if !session_dir.exists() {
227 return Ok(Vec::new());
228 }
229
230 let mut metas = Vec::new();
231
232 for entry in std::fs::read_dir(&session_dir)? {
233 let entry = entry?;
234 let path = entry.path();
235
236 if !path
237 .file_name()
238 .map(|n| n.to_string_lossy().starts_with("cp_"))
239 .unwrap_or(false)
240 {
241 continue;
242 }
243
244 if path.extension().map(|e| e != "json").unwrap_or(true) {
245 continue;
246 }
247
248 if let Ok(content) = std::fs::read_to_string(&path) {
250 if let Ok(data) = serde_json::from_str::<serde_json::Value>(&content) {
251 let (valid, _) = ChecksumUtils::verify_checksum(&data);
252
253 if valid {
254 if let Ok(meta) = serde_json::from_value::<CheckpointMeta>(data) {
255 metas.push(meta);
256 }
257 }
258 }
259 }
260 }
261
262 metas.sort_by_key(|b| std::cmp::Reverse(b.created_at));
264
265 Ok(metas)
266 }
267
268 async fn delete(
269 &self,
270 session_id: &SessionId,
271 checkpoint_id: &CheckpointId,
272 ) -> Layer2Result<bool> {
273 let session_dir = self.session_dir(session_id);
274
275 if !session_dir.exists() {
276 return Ok(false);
277 }
278
279 let pattern = format!("cp_*_{}.json", checkpoint_id);
280
281 if let Some(path) = glob::glob(session_dir.join(&pattern).to_string_lossy().as_ref())?
282 .flatten()
283 .next()
284 {
285 std::fs::remove_file(&path)?;
286 return Ok(true);
287 }
288
289 Ok(false)
290 }
291
292 fn verify(&self, path: &Path) -> Layer2Result<bool> {
293 if !path.exists() {
294 return Ok(false);
295 }
296
297 let content = std::fs::read_to_string(path)?;
298 let data: serde_json::Value = serde_json::from_str(&content)?;
299
300 let (valid, _) = ChecksumUtils::verify_checksum(&data);
301 Ok(valid)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a minimal checkpoint belonging to a brand-new session.
    fn sample_checkpoint() -> CheckpointData {
        CheckpointData {
            checkpoint_id: CheckpointId::new(),
            session_id: SessionId::new(),
            created_at: Utc::now(),
            trigger: "manual".to_string(),
            iteration: 1,
            messages: vec![serde_json::json!({"role": "user", "content": "test"})],
            tool_calls_pending: Vec::new(),
            tool_results: serde_json::Value::Null,
            tokens_used: 100,
            cost_estimate: 0.01,
            resume_hint: None,
        }
    }

    #[test]
    fn test_checkpoint_writer_creation() {
        let temp_dir = TempDir::new().unwrap();
        let writer = CheckpointWriter::new(temp_dir.path());

        assert_eq!(writer.storage_path.as_path(), temp_dir.path());
        assert_eq!(writer.max_backups, 3);

        let writer = writer.with_max_backups(5);
        assert_eq!(writer.max_backups, 5);
    }

    #[tokio::test]
    async fn test_save_and_load_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let writer = CheckpointWriter::new(temp_dir.path());

        let data = sample_checkpoint();
        let session_id = data.session_id.clone();
        let saved_id = writer.save(&data).await.unwrap();

        let loaded = writer.load(&session_id, None).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.checkpoint_id, saved_id);
        assert_eq!(loaded.iteration, 1);
    }

    #[tokio::test]
    async fn test_load_by_id_and_delete() {
        let temp_dir = TempDir::new().unwrap();
        let writer = CheckpointWriter::new(temp_dir.path());

        let data = sample_checkpoint();
        let session_id = data.session_id.clone();
        let saved_id = writer.save(&data).await.unwrap();

        let loaded = writer.load(&session_id, Some(&saved_id)).await.unwrap();
        assert_eq!(loaded.map(|cp| cp.checkpoint_id), Some(saved_id.clone()));

        assert!(writer.delete(&session_id, &saved_id).await.unwrap());
        assert!(!writer.delete(&session_id, &saved_id).await.unwrap());
        assert!(writer
            .load(&session_id, Some(&saved_id))
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_load_by_id_with_glob_metacharacters_in_storage_path() {
        let temp_dir = TempDir::new().unwrap();
        // `[1]` would be a character class if the path were used as a glob pattern.
        let writer = CheckpointWriter::new(temp_dir.path().join("store[1]"));

        let data = sample_checkpoint();
        let session_id = data.session_id.clone();
        let saved_id = writer.save(&data).await.unwrap();

        let loaded = writer.load(&session_id, Some(&saved_id)).await.unwrap();
        assert_eq!(loaded.map(|cp| cp.checkpoint_id), Some(saved_id.clone()));
        assert!(writer.delete(&session_id, &saved_id).await.unwrap());
    }

    #[tokio::test]
    async fn test_unknown_session_is_empty() {
        let temp_dir = TempDir::new().unwrap();
        let writer = CheckpointWriter::new(temp_dir.path());
        let session_id = SessionId::new();

        assert!(writer.load(&session_id, None).await.unwrap().is_none());
        assert!(writer.list(&session_id).await.unwrap().is_empty());
        assert!(!writer
            .delete(&session_id, &CheckpointId::new())
            .await
            .unwrap());
    }
}