Skip to main content

haystack_core/graph/
snapshot.rs

//! HLSS v1 binary snapshot format — write/read EntityGraph state to disk.
//!
//! Format: [Header][Zstd-compressed Zinc body][CRC32 footer]
//! Header: "HLSS" (4 bytes) + format_version (u16 LE) + entity_count (u32 LE)
//!         + timestamp (i64 LE, Unix nanos) + graph_version (u64 LE) = 26 bytes
//! Body: Zstd-compressed Zinc-encoded grid of all entities
//! Footer: CRC32 (u32 LE) over header + compressed body
8
9use std::collections::BTreeSet;
10use std::path::{Path, PathBuf};
11
12use crate::codecs::Codec;
13use crate::codecs::zinc::ZincCodec;
14use crate::data::{HCol, HDict, HGrid};
15use crate::graph::shared::SharedGraph;
16
/// Header metadata describing a snapshot that has been loaded from disk.
///
/// All numeric fields are copied verbatim from the HLSS header; they are not
/// recomputed from the decoded body.
#[derive(Debug, Clone)]
pub struct SnapshotMeta {
    /// HLSS format version stored in the header.
    pub format_version: u16,
    /// Entity count recorded in the header by the writer.
    pub entity_count: u32,
    /// Creation time recorded by the writer, in nanoseconds since the Unix epoch.
    pub timestamp: i64,
    /// Graph version recorded in the header at the time the snapshot was taken.
    pub graph_version: u64,
    /// Path the snapshot was loaded from (as supplied by the caller).
    pub path: PathBuf,
}
26
27/// Errors during snapshot operations.
28#[derive(Debug, thiserror::Error)]
29pub enum SnapshotError {
30    #[error("I/O error: {0}")]
31    Io(#[from] std::io::Error),
32    #[error("invalid magic bytes")]
33    InvalidMagic,
34    #[error("unsupported format version: {0}")]
35    UnsupportedVersion(u16),
36    #[error("CRC32 mismatch: expected {expected:#010x}, got {actual:#010x}")]
37    CrcMismatch { expected: u32, actual: u32 },
38    #[error("decompression error: {0}")]
39    Decompression(String),
40    #[error("codec error: {0}")]
41    Codec(String),
42}
43
/// Persists the contents of a [`SharedGraph`] as HLSS snapshot files inside a
/// single directory, pruning older files after each write.
pub struct SnapshotWriter {
    /// Directory that receives the snapshot files.
    dir: PathBuf,
    /// Upper bound on the number of snapshot files kept after rotation.
    max_snapshots: usize,
    /// Zstd compression level passed to the encoder.
    compression_level: i32,
}
50
51impl SnapshotWriter {
52    pub fn new(dir: PathBuf, max_snapshots: usize) -> Self {
53        Self {
54            dir,
55            max_snapshots,
56            compression_level: 3,
57        }
58    }
59
60    pub fn with_compression(mut self, level: i32) -> Self {
61        self.compression_level = level;
62        self
63    }
64
65    /// Write a snapshot of the graph. Returns path to the snapshot file.
66    /// Uses atomic write: write to temp file, then rename.
67    pub fn write(&self, graph: &SharedGraph) -> Result<PathBuf, SnapshotError> {
68        std::fs::create_dir_all(&self.dir)?;
69
70        let (grid, version) = graph.read(|g| {
71            let entities = g.all();
72            let version = g.version();
73
74            // Collect all unique tag names across entities for columns.
75            let mut col_names = BTreeSet::new();
76            for entity in &entities {
77                for (key, _) in entity.iter() {
78                    col_names.insert(key.to_owned());
79                }
80            }
81            let cols: Vec<HCol> = col_names.iter().map(|n| HCol::new(n)).collect();
82            let rows: Vec<HDict> = entities.into_iter().cloned().collect();
83            let grid = HGrid::from_parts(HDict::new(), cols, rows);
84            (grid, version)
85        });
86
87        // Encode to Zinc
88        let zinc = ZincCodec;
89        let zinc_str = zinc
90            .encode_grid(&grid)
91            .map_err(|e| SnapshotError::Codec(e.to_string()))?;
92
93        // Compress
94        let compressed = zstd::encode_all(zinc_str.as_bytes(), self.compression_level)
95            .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
96
97        // Build header
98        let timestamp = std::time::SystemTime::now()
99            .duration_since(std::time::UNIX_EPOCH)
100            .unwrap_or_default()
101            .as_nanos() as i64;
102        let entity_count = grid.rows.len() as u32;
103
104        let mut buf = Vec::new();
105        buf.extend_from_slice(b"HLSS");
106        buf.extend_from_slice(&1u16.to_le_bytes()); // format version
107        buf.extend_from_slice(&entity_count.to_le_bytes());
108        buf.extend_from_slice(&timestamp.to_le_bytes());
109        buf.extend_from_slice(&version.to_le_bytes());
110        buf.extend_from_slice(&compressed);
111
112        // CRC32 over header + body
113        let crc = crc32fast::hash(&buf);
114        buf.extend_from_slice(&crc.to_le_bytes());
115
116        // Atomic write
117        let filename = format!("snapshot-{version}.hlss");
118        let final_path = self.dir.join(&filename);
119        let tmp_path = self.dir.join(format!(".{filename}.tmp"));
120        std::fs::write(&tmp_path, &buf)?;
121        std::fs::rename(&tmp_path, &final_path)?;
122
123        // Rotate old snapshots
124        self.rotate()?;
125
126        Ok(final_path)
127    }
128
129    /// Remove old snapshots beyond max_snapshots. Returns number removed.
130    pub fn rotate(&self) -> Result<usize, SnapshotError> {
131        let mut snapshots = Self::list_snapshots(&self.dir)?;
132        if snapshots.len() <= self.max_snapshots {
133            return Ok(0);
134        }
135        // Sort by filename (has version number) — oldest first
136        snapshots.sort();
137        let to_remove = snapshots.len() - self.max_snapshots;
138        for path in &snapshots[..to_remove] {
139            std::fs::remove_file(path)?;
140        }
141        Ok(to_remove)
142    }
143
144    fn list_snapshots(dir: &Path) -> Result<Vec<PathBuf>, SnapshotError> {
145        if !dir.exists() {
146            return Ok(Vec::new());
147        }
148        Ok(std::fs::read_dir(dir)?
149            .filter_map(|e| e.ok())
150            .filter(|e| e.path().extension().is_some_and(|ext| ext == "hlss"))
151            .filter(|e| !e.file_name().to_str().unwrap_or("").starts_with('.'))
152            .map(|e| e.path())
153            .collect())
154    }
155}
156
157/// Reads and restores snapshots from disk in HLSS format.
158pub struct SnapshotReader;
159
160impl SnapshotReader {
161    /// Find the latest snapshot in a directory.
162    pub fn find_latest(dir: &Path) -> Result<Option<PathBuf>, SnapshotError> {
163        let mut snapshots = SnapshotWriter::list_snapshots(dir)?;
164        if snapshots.is_empty() {
165            return Ok(None);
166        }
167        snapshots.sort();
168        Ok(Some(snapshots.pop().unwrap()))
169    }
170
171    /// Load a snapshot file and import entities into the graph.
172    pub fn load(path: &Path, graph: &SharedGraph) -> Result<SnapshotMeta, SnapshotError> {
173        let data = std::fs::read(path)?;
174        Self::load_from_bytes(&data, path, graph)
175    }
176
177    /// Load from raw bytes (useful for testing).
178    pub fn load_from_bytes(
179        data: &[u8],
180        path: &Path,
181        graph: &SharedGraph,
182    ) -> Result<SnapshotMeta, SnapshotError> {
183        // Need at least header (26) + CRC (4) = 30 bytes
184        if data.len() < 30 {
185            return Err(SnapshotError::InvalidMagic);
186        }
187
188        // Validate magic
189        if &data[0..4] != b"HLSS" {
190            return Err(SnapshotError::InvalidMagic);
191        }
192
193        // Parse header
194        let format_version = u16::from_le_bytes([data[4], data[5]]);
195        if format_version != 1 {
196            return Err(SnapshotError::UnsupportedVersion(format_version));
197        }
198        let entity_count = u32::from_le_bytes(data[6..10].try_into().unwrap());
199        let timestamp = i64::from_le_bytes(data[10..18].try_into().unwrap());
200        let graph_version = u64::from_le_bytes(data[18..26].try_into().unwrap());
201
202        // Validate CRC32
203        let crc_offset = data.len() - 4;
204        let expected_crc = u32::from_le_bytes(data[crc_offset..].try_into().unwrap());
205        let actual_crc = crc32fast::hash(&data[..crc_offset]);
206        if expected_crc != actual_crc {
207            return Err(SnapshotError::CrcMismatch {
208                expected: expected_crc,
209                actual: actual_crc,
210            });
211        }
212
213        // Decompress
214        let compressed = &data[26..crc_offset];
215        const MAX_DECOMPRESSED_SIZE: usize = 1024 * 1024 * 1024; // 1 GB
216        let decompressed = zstd::decode_all(compressed)
217            .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
218        if decompressed.len() > MAX_DECOMPRESSED_SIZE {
219            return Err(SnapshotError::Decompression(format!(
220                "decompressed data too large: {} bytes (max {})",
221                decompressed.len(),
222                MAX_DECOMPRESSED_SIZE
223            )));
224        }
225
226        // Decode Zinc
227        let zinc_str =
228            std::str::from_utf8(&decompressed).map_err(|e| SnapshotError::Codec(e.to_string()))?;
229        let zinc = ZincCodec;
230        let grid = zinc
231            .decode_grid(zinc_str)
232            .map_err(|e| SnapshotError::Codec(e.to_string()))?;
233
234        // Import entities into graph — each row with an "id" tag is added.
235        graph.write(|g| {
236            for row in &grid.rows {
237                if let Some(id_ref) = row.id() {
238                    g.add(row.clone()).ok(); // ignore duplicates on restore
239                    let _ = id_ref; // id_ref used only for the guard
240                }
241            }
242        });
243
244        Ok(SnapshotMeta {
245            format_version,
246            entity_count,
247            timestamp,
248            graph_version,
249            path: path.to_path_buf(),
250        })
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EntityGraph;
    use crate::kinds::{HRef, Kind};

    /// Build a small site entity with `id`, `site` and `dis` tags.
    fn make_site(id: &str) -> HDict {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));
        site
    }

    /// A shared graph with no entities.
    fn empty_graph() -> SharedGraph {
        SharedGraph::new(EntityGraph::new())
    }

    /// A shared graph holding three sites: `site-1`, `site-2`, `site-3`.
    fn populated_graph() -> SharedGraph {
        let shared = empty_graph();
        for id in ["site-1", "site-2", "site-3"] {
            shared.add(make_site(id)).unwrap();
        }
        shared
    }

    /// A writer targeting `dir` that keeps at most `keep` snapshots.
    fn writer_in(dir: &Path, keep: usize) -> SnapshotWriter {
        SnapshotWriter::new(dir.to_path_buf(), keep)
    }

    #[test]
    fn write_and_read_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let source = populated_graph();

        let snap_path = writer_in(tmp.path(), 5).write(&source).unwrap();
        assert!(snap_path.exists());

        // Restore into a brand-new graph and compare.
        let restored = empty_graph();
        let meta = SnapshotReader::load(&snap_path, &restored).unwrap();

        assert_eq!(meta.format_version, 1);
        assert_eq!(meta.entity_count, 3);
        assert_eq!(meta.graph_version, 3);
        assert_eq!(restored.len(), 3);
        for id in ["site-1", "site-2", "site-3"] {
            assert!(restored.contains(id));
        }
    }

    #[test]
    fn find_latest_returns_most_recent() {
        let tmp = tempfile::tempdir().unwrap();
        let graph = empty_graph();
        let writer = writer_in(tmp.path(), 5);

        graph.add(make_site("site-1")).unwrap();
        let _first = writer.write(&graph).unwrap();

        graph.add(make_site("site-2")).unwrap();
        let second = writer.write(&graph).unwrap();

        let latest = SnapshotReader::find_latest(tmp.path()).unwrap().unwrap();
        assert_eq!(latest, second);
    }

    #[test]
    fn rotate_removes_old_snapshots() {
        let tmp = tempfile::tempdir().unwrap();
        let graph = empty_graph();
        let writer = writer_in(tmp.path(), 2);

        // Four writes with a retention of two.
        for i in 0..4 {
            graph.add(make_site(&format!("s-{i}"))).unwrap();
            writer.write(&graph).unwrap();
        }

        let remaining = SnapshotWriter::list_snapshots(tmp.path()).unwrap();
        assert_eq!(remaining.len(), 2);
    }

    #[test]
    fn corrupt_crc_detected() {
        let tmp = tempfile::tempdir().unwrap();
        let source = populated_graph();
        let snap_path = writer_in(tmp.path(), 5).write(&source).unwrap();

        let mut bytes = std::fs::read(&snap_path).unwrap();
        // Flip every bit of one byte inside the compressed body.
        if bytes.len() > 30 {
            bytes[28] ^= 0xFF;
        }

        let target = empty_graph();
        let outcome = SnapshotReader::load_from_bytes(&bytes, &snap_path, &target);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch { .. })));
    }

    #[test]
    fn invalid_magic_rejected() {
        let bytes = b"NOPE_this_is_not_a_snapshot_at_all";
        let target = empty_graph();
        let outcome = SnapshotReader::load_from_bytes(bytes, Path::new("bad.hlss"), &target);
        assert!(matches!(outcome, Err(SnapshotError::InvalidMagic)));
    }

    #[test]
    fn empty_graph_produces_valid_snapshot() {
        let tmp = tempfile::tempdir().unwrap();
        let source = empty_graph();
        let snap_path = writer_in(tmp.path(), 5).write(&source).unwrap();

        let restored = empty_graph();
        let meta = SnapshotReader::load(&snap_path, &restored).unwrap();
        assert_eq!(meta.entity_count, 0);
        assert_eq!(restored.len(), 0);
    }

    #[test]
    fn find_latest_on_empty_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let found = SnapshotReader::find_latest(tmp.path()).unwrap();
        assert!(found.is_none());
    }

    #[test]
    fn too_short_data_rejected() {
        let bytes = b"HLSS_short";
        let target = empty_graph();
        let outcome = SnapshotReader::load_from_bytes(bytes, Path::new("x.hlss"), &target);
        assert!(matches!(outcome, Err(SnapshotError::InvalidMagic)));
    }

    #[test]
    fn unsupported_version_rejected() {
        // Well-formed header and CRC, but the format version is 99.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"HLSS");
        bytes.extend_from_slice(&99u16.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0i64.to_le_bytes());
        bytes.extend_from_slice(&0u64.to_le_bytes());
        // Four placeholder body bytes so the total length reaches 30+.
        bytes.extend_from_slice(&[0u8; 4]);
        let crc = crc32fast::hash(&bytes);
        bytes.extend_from_slice(&crc.to_le_bytes());

        let target = empty_graph();
        let outcome = SnapshotReader::load_from_bytes(&bytes, Path::new("x.hlss"), &target);
        assert!(matches!(outcome, Err(SnapshotError::UnsupportedVersion(99))));
    }
}
404}