// scud/attractor/checkpoint.rs
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

use super::context::ContextSnapshot;
use super::outcome::StageStatus;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14 pub timestamp: String,
16 pub current_node: String,
18 pub completed_nodes: Vec<String>,
20 pub node_retries: HashMap<String, u32>,
22 pub node_statuses: HashMap<String, StageStatus>,
24 pub context: ContextSnapshot,
26 pub log: Vec<LogEntry>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct LogEntry {
33 pub timestamp: String,
34 pub node_id: String,
35 pub message: String,
36}
37
38impl Checkpoint {
39 pub fn new(current_node: &str, context: ContextSnapshot) -> Self {
41 Self {
42 timestamp: chrono::Utc::now().to_rfc3339(),
43 current_node: current_node.to_string(),
44 completed_nodes: vec![],
45 node_retries: HashMap::new(),
46 node_statuses: HashMap::new(),
47 context,
48 log: vec![],
49 }
50 }
51
52 pub fn mark_completed(&mut self, node_id: &str, status: StageStatus) {
54 if !self.completed_nodes.contains(&node_id.to_string()) {
55 self.completed_nodes.push(node_id.to_string());
56 }
57 self.node_statuses.insert(node_id.to_string(), status);
58 }
59
60 pub fn increment_retry(&mut self, node_id: &str) -> u32 {
62 let count = self.node_retries.entry(node_id.to_string()).or_insert(0);
63 *count += 1;
64 *count
65 }
66
67 pub fn retry_count(&self, node_id: &str) -> u32 {
69 self.node_retries.get(node_id).copied().unwrap_or(0)
70 }
71
72 pub fn log(&mut self, node_id: &str, message: impl Into<String>) {
74 self.log.push(LogEntry {
75 timestamp: chrono::Utc::now().to_rfc3339(),
76 node_id: node_id.to_string(),
77 message: message.into(),
78 });
79 }
80
81 pub fn save(&self, path: &Path) -> Result<()> {
83 let json = serde_json::to_string_pretty(self)
84 .context("Failed to serialize checkpoint")?;
85 std::fs::write(path, json).context("Failed to write checkpoint file")?;
86 Ok(())
87 }
88
89 pub fn load(path: &Path) -> Result<Self> {
91 let json = std::fs::read_to_string(path).context("Failed to read checkpoint file")?;
92 let checkpoint: Self =
93 serde_json::from_str(&json).context("Failed to deserialize checkpoint")?;
94 Ok(checkpoint)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the one-entry context snapshot shared by the tests in this module.
    fn sample_snapshot() -> ContextSnapshot {
        let mut ctx_values = HashMap::new();
        ctx_values.insert("key".into(), serde_json::json!("value"));
        ContextSnapshot::from(ctx_values)
    }

    /// A checkpoint written with `save` comes back from `load` with its progress data intact.
    #[test]
    fn test_checkpoint_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut cp = Checkpoint::new("node_a", sample_snapshot());
        cp.mark_completed("node_a", StageStatus::Success);
        cp.increment_retry("node_b");
        cp.log("node_a", "Did something");

        cp.save(&path).unwrap();
        let loaded = Checkpoint::load(&path).unwrap();

        assert_eq!(loaded.current_node, "node_a");
        assert_eq!(loaded.completed_nodes, vec!["node_a"]);
        assert!(loaded.node_statuses.contains_key("node_a"));
        assert_eq!(loaded.retry_count("node_b"), 1);
        // A node that was never retried reports zero.
        assert_eq!(loaded.retry_count("node_a"), 0);
        assert_eq!(loaded.log.len(), 1);
        assert_eq!(loaded.log[0].node_id, "node_a");
        assert_eq!(loaded.log[0].message, "Did something");
    }

    /// Completing the same node twice keeps a single entry in `completed_nodes`.
    #[test]
    fn test_mark_completed_does_not_duplicate_nodes() {
        let mut cp = Checkpoint::new("node_a", sample_snapshot());
        cp.mark_completed("node_a", StageStatus::Success);
        cp.mark_completed("node_a", StageStatus::Success);

        assert_eq!(cp.completed_nodes, vec!["node_a"]);
        assert_eq!(cp.node_statuses.len(), 1);
    }

    /// Retry counters are kept per node and start at zero.
    #[test]
    fn test_retry_counts_are_tracked_per_node() {
        let mut cp = Checkpoint::new("node_a", sample_snapshot());

        assert_eq!(cp.retry_count("node_a"), 0);
        assert_eq!(cp.increment_retry("node_a"), 1);
        assert_eq!(cp.increment_retry("node_a"), 2);
        assert_eq!(cp.increment_retry("node_b"), 1);
        assert_eq!(cp.retry_count("node_a"), 2);
        assert_eq!(cp.retry_count("node_b"), 1);
    }
}