Skip to main content

sh_layer2/checkpoint_system/
writer.rs

1//! # Checkpoint Writer
2//!
3//! 检查点写入器实现。
4
5use async_trait::async_trait;
6use chrono::Utc;
7use std::path::{Path, PathBuf};
8
9use crate::types::{CheckpointId, CheckpointMeta, Layer2Result, SessionId};
10
11use super::{AtomicFileWriter, CheckpointData, CheckpointSystemTrait, ChecksumUtils};
12
/// Checkpoint writer: stores session checkpoints as JSON files under a storage root.
pub struct CheckpointWriter {
    /// Root directory; each session gets `<storage_path>/<session_id>/checkpoints`.
    storage_path: PathBuf,
    /// Number of `*.backup` files kept in a session directory when pruning.
    max_backups: usize,
    /// Helper whose `write_atomic` is used to write checkpoint files.
    atomic_writer: AtomicFileWriter,
}
19
20impl CheckpointWriter {
21    /// 创建新的检查点写入器
22    pub fn new(storage_path: impl AsRef<Path>) -> Self {
23        Self {
24            storage_path: storage_path.as_ref().to_path_buf(),
25            max_backups: 3,
26            atomic_writer: AtomicFileWriter::new(),
27        }
28    }
29
30    /// 配置最大备份数
31    pub fn with_max_backups(mut self, max: usize) -> Self {
32        self.max_backups = max;
33        self
34    }
35
36    /// 获取会话目录
37    fn session_dir(&self, session_id: &SessionId) -> PathBuf {
38        self.storage_path
39            .join(session_id.to_string())
40            .join("checkpoints")
41    }
42
43    /// 备份现有检查点
44    fn backup_checkpoint(&self, filepath: &Path) {
45        if !filepath.exists() {
46            return;
47        }
48
49        let ext = filepath.extension().and_then(|e| e.to_str()).unwrap_or("");
50
51        let backup_path = filepath.with_extension(format!("{}.backup", ext));
52        let _ = std::fs::copy(filepath, backup_path);
53    }
54
55    /// 清理旧备份
56    fn prune_backups(&self, session_dir: &Path) {
57        let mut backups: Vec<_> = std::fs::read_dir(session_dir)
58            .ok()
59            .into_iter()
60            .flatten()
61            .filter_map(|e| e.ok())
62            .filter(|e| {
63                e.path()
64                    .extension()
65                    .map(|ext| ext == "backup")
66                    .unwrap_or(false)
67            })
68            .collect();
69
70        // 按修改时间排序(最新在前)
71        backups.sort_by(|a, b| {
72            let a_time = a
73                .metadata()
74                .and_then(|m| m.modified())
75                .unwrap_or(std::time::UNIX_EPOCH);
76            let b_time = b
77                .metadata()
78                .and_then(|m| m.modified())
79                .unwrap_or(std::time::UNIX_EPOCH);
80            b_time.cmp(&a_time)
81        });
82
83        // 删除超过限制的备份
84        for old_backup in backups.into_iter().skip(self.max_backups) {
85            let _ = std::fs::remove_file(old_backup.path());
86        }
87    }
88
89    /// 更新 latest 引用
90    fn update_latest(&self, session_dir: &Path, checkpoint_path: &Path) -> Layer2Result<()> {
91        let latest_path = session_dir.join("latest.json");
92
93        #[cfg(windows)]
94        {
95            // Windows: 使用复制(符号链接需要管理员权限)
96            std::fs::copy(checkpoint_path, &latest_path)?;
97        }
98
99        #[cfg(not(windows))]
100        {
101            // Unix: 使用原子符号链接
102            let temp_link = session_dir.join(format!(".tmp_latest_{}", uuid::Uuid::new_v4()));
103            std::os::unix::fs::symlink(checkpoint_path.file_name().unwrap(), &temp_link)?;
104            std::fs::rename(&temp_link, &latest_path)?;
105        }
106
107        Ok(())
108    }
109}
110
111#[async_trait]
112impl CheckpointSystemTrait for CheckpointWriter {
113    async fn save(&self, data: &CheckpointData) -> Layer2Result<CheckpointId> {
114        let session_dir = self.session_dir(&data.session_id);
115        std::fs::create_dir_all(&session_dir)?;
116
117        // 生成文件名
118        let timestamp = Utc::now().format("%Y%m%d_%H%M%S");
119        let checkpoint_id = data.checkpoint_id.clone();
120        let filename = format!("cp_{}_{}.json", timestamp, checkpoint_id);
121        let filepath = session_dir.join(&filename);
122
123        // 序列化数据
124        let mut json_data = serde_json::to_value(data)?;
125        json_data = ChecksumUtils::add_checksum(json_data);
126        let json_content = serde_json::to_string_pretty(&json_data)?;
127
128        // 备份现有 latest
129        let latest_path = session_dir.join("latest.json");
130        self.backup_checkpoint(&latest_path);
131
132        // 原子写入
133        self.atomic_writer.write_atomic(&filepath, &json_content)?;
134
135        // 更新 latest 引用
136        let _ = self.update_latest(&session_dir, &filepath);
137
138        // 清理旧备份
139        self.prune_backups(&session_dir);
140
141        Ok(checkpoint_id)
142    }
143
144    async fn load(
145        &self,
146        session_id: &SessionId,
147        checkpoint_id: Option<&CheckpointId>,
148    ) -> Layer2Result<Option<CheckpointData>> {
149        let session_dir = self.session_dir(session_id);
150
151        if !session_dir.exists() {
152            return Ok(None);
153        }
154
155        // 确定要加载的检查点
156        let filepath = if let Some(id) = checkpoint_id {
157            // 查找指定的检查点
158            let pattern = format!("cp_*_{}.json", id);
159            let matches: Vec<_> =
160                glob::glob(session_dir.join(&pattern).to_string_lossy().as_ref())?
161                    .filter_map(|e| e.ok())
162                    .collect();
163
164            if matches.is_empty() {
165                return Ok(None);
166            }
167            matches[0].clone()
168        } else {
169            // 加载最新的检查点
170            let latest_path = session_dir.join("latest.json");
171            if latest_path.exists() {
172                latest_path
173            } else {
174                // 查找最新的检查点
175                let mut checkpoints: Vec<_> = std::fs::read_dir(&session_dir)?
176                    .filter_map(|e| e.ok())
177                    .filter(|e| {
178                        e.file_name().to_string_lossy().starts_with("cp_")
179                            && e.path()
180                                .extension()
181                                .map(|ext| ext == "json")
182                                .unwrap_or(false)
183                    })
184                    .collect();
185
186                if checkpoints.is_empty() {
187                    return Ok(None);
188                }
189
190                checkpoints.sort_by(|a, b| {
191                    let a_time = a
192                        .metadata()
193                        .and_then(|m| m.modified())
194                        .unwrap_or(std::time::UNIX_EPOCH);
195                    let b_time = b
196                        .metadata()
197                        .and_then(|m| m.modified())
198                        .unwrap_or(std::time::UNIX_EPOCH);
199                    b_time.cmp(&a_time)
200                });
201
202                checkpoints[0].path()
203            }
204        };
205
206        if !filepath.exists() {
207            return Ok(None);
208        }
209
210        // 读取并验证
211        let content = std::fs::read_to_string(&filepath)?;
212        let data: serde_json::Value = serde_json::from_str(&content)?;
213
214        let (valid, _) = ChecksumUtils::verify_checksum(&data);
215        if !valid {
216            return Err(anyhow::anyhow!("Checkpoint checksum verification failed"));
217        }
218
219        let checkpoint: CheckpointData = serde_json::from_value(data)?;
220        Ok(Some(checkpoint))
221    }
222
223    async fn list(&self, session_id: &SessionId) -> Layer2Result<Vec<CheckpointMeta>> {
224        let session_dir = self.session_dir(session_id);
225
226        if !session_dir.exists() {
227            return Ok(Vec::new());
228        }
229
230        let mut metas = Vec::new();
231
232        for entry in std::fs::read_dir(&session_dir)? {
233            let entry = entry?;
234            let path = entry.path();
235
236            if !path
237                .file_name()
238                .map(|n| n.to_string_lossy().starts_with("cp_"))
239                .unwrap_or(false)
240            {
241                continue;
242            }
243
244            if path.extension().map(|e| e != "json").unwrap_or(true) {
245                continue;
246            }
247
248            // 读取并验证
249            if let Ok(content) = std::fs::read_to_string(&path) {
250                if let Ok(data) = serde_json::from_str::<serde_json::Value>(&content) {
251                    let (valid, _) = ChecksumUtils::verify_checksum(&data);
252
253                    if valid {
254                        if let Ok(meta) = serde_json::from_value::<CheckpointMeta>(data) {
255                            metas.push(meta);
256                        }
257                    }
258                }
259            }
260        }
261
262        // 按创建时间排序(最新在前)
263        metas.sort_by_key(|b| std::cmp::Reverse(b.created_at));
264
265        Ok(metas)
266    }
267
268    async fn delete(
269        &self,
270        session_id: &SessionId,
271        checkpoint_id: &CheckpointId,
272    ) -> Layer2Result<bool> {
273        let session_dir = self.session_dir(session_id);
274
275        if !session_dir.exists() {
276            return Ok(false);
277        }
278
279        let pattern = format!("cp_*_{}.json", checkpoint_id);
280
281        if let Some(path) = glob::glob(session_dir.join(&pattern).to_string_lossy().as_ref())?
282            .flatten()
283            .next()
284        {
285            std::fs::remove_file(&path)?;
286            return Ok(true);
287        }
288
289        Ok(false)
290    }
291
292    fn verify(&self, path: &Path) -> Layer2Result<bool> {
293        if !path.exists() {
294            return Ok(false);
295        }
296
297        let content = std::fs::read_to_string(path)?;
298        let data: serde_json::Value = serde_json::from_str(&content)?;
299
300        let (valid, _) = ChecksumUtils::verify_checksum(&data);
301        Ok(valid)
302    }
303}
304
305#[cfg(test)]
306mod tests {
307    use super::*;
308    use tempfile::TempDir;
309
310    #[test]
311    fn test_checkpoint_writer_creation() {
312        let temp_dir = TempDir::new().unwrap();
313        let writer = CheckpointWriter::new(temp_dir.path());
314
315        assert!(writer.storage_path.exists() || writer.storage_path.parent().is_some());
316    }
317
318    #[tokio::test]
319    async fn test_save_and_load_checkpoint() {
320        let temp_dir = TempDir::new().unwrap();
321        let writer = CheckpointWriter::new(temp_dir.path());
322
323        let data = CheckpointData {
324            checkpoint_id: CheckpointId::new(),
325            session_id: SessionId::new(),
326            created_at: Utc::now(),
327            trigger: "manual".to_string(),
328            iteration: 1,
329            messages: vec![serde_json::json!({"role": "user", "content": "test"})],
330            tool_calls_pending: Vec::new(),
331            tool_results: serde_json::Value::Null,
332            tokens_used: 100,
333            cost_estimate: 0.01,
334            resume_hint: None,
335        };
336
337        let session_id = data.session_id.clone();
338        let saved_id = writer.save(&data).await.unwrap();
339
340        let loaded = writer.load(&session_id, None).await.unwrap();
341        assert!(loaded.is_some());
342
343        let loaded = loaded.unwrap();
344        assert_eq!(loaded.checkpoint_id, saved_id);
345        assert_eq!(loaded.iteration, 1);
346    }
347}