1use std::collections::BTreeSet;
10use std::path::{Path, PathBuf};
11
12use crate::codecs::Codec;
13use crate::codecs::zinc::ZincCodec;
14use crate::data::{HCol, HDict, HGrid};
15use crate::graph::shared::SharedGraph;
16
/// Header information of a snapshot that was successfully loaded by
/// [`SnapshotReader`], together with where it came from.
#[derive(Clone, Debug)]
pub struct SnapshotMeta {
    /// On-disk format version taken from the snapshot header.
    pub format_version: u16,
    /// Number of entities the writer recorded in the header.
    pub entity_count: u32,
    /// Write time stored in the header (nanoseconds since the Unix epoch, as
    /// produced by [`SnapshotWriter`]).
    pub timestamp: i64,
    /// Graph version captured when the snapshot was taken.
    pub graph_version: u64,
    /// Filesystem path the snapshot was read from.
    pub path: PathBuf,
}
26
27#[derive(Debug, thiserror::Error)]
29pub enum SnapshotError {
30 #[error("I/O error: {0}")]
31 Io(#[from] std::io::Error),
32 #[error("invalid magic bytes")]
33 InvalidMagic,
34 #[error("unsupported format version: {0}")]
35 UnsupportedVersion(u16),
36 #[error("CRC32 mismatch: expected {expected:#010x}, got {actual:#010x}")]
37 CrcMismatch { expected: u32, actual: u32 },
38 #[error("decompression error: {0}")]
39 Decompression(String),
40 #[error("codec error: {0}")]
41 Codec(String),
42}
43
44pub struct SnapshotWriter {
46 dir: PathBuf,
47 max_snapshots: usize,
48 compression_level: i32,
49}
50
51impl SnapshotWriter {
52 pub fn new(dir: PathBuf, max_snapshots: usize) -> Self {
53 Self {
54 dir,
55 max_snapshots,
56 compression_level: 3,
57 }
58 }
59
60 pub fn with_compression(mut self, level: i32) -> Self {
61 self.compression_level = level;
62 self
63 }
64
65 pub fn write(&self, graph: &SharedGraph) -> Result<PathBuf, SnapshotError> {
68 std::fs::create_dir_all(&self.dir)?;
69
70 let (grid, version) = graph.read(|g| {
71 let entities = g.all();
72 let version = g.version();
73
74 let mut col_names = BTreeSet::new();
76 for entity in &entities {
77 for (key, _) in entity.iter() {
78 col_names.insert(key.to_owned());
79 }
80 }
81 let cols: Vec<HCol> = col_names.iter().map(|n| HCol::new(n)).collect();
82 let rows: Vec<HDict> = entities.into_iter().cloned().collect();
83 let grid = HGrid::from_parts(HDict::new(), cols, rows);
84 (grid, version)
85 });
86
87 let zinc = ZincCodec;
89 let zinc_str = zinc
90 .encode_grid(&grid)
91 .map_err(|e| SnapshotError::Codec(e.to_string()))?;
92
93 let compressed = zstd::encode_all(zinc_str.as_bytes(), self.compression_level)
95 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
96
97 let timestamp = std::time::SystemTime::now()
99 .duration_since(std::time::UNIX_EPOCH)
100 .unwrap_or_default()
101 .as_nanos() as i64;
102 let entity_count = grid.rows.len() as u32;
103
104 let mut buf = Vec::new();
105 buf.extend_from_slice(b"HLSS");
106 buf.extend_from_slice(&1u16.to_le_bytes()); buf.extend_from_slice(&entity_count.to_le_bytes());
108 buf.extend_from_slice(×tamp.to_le_bytes());
109 buf.extend_from_slice(&version.to_le_bytes());
110 buf.extend_from_slice(&compressed);
111
112 let crc = crc32fast::hash(&buf);
114 buf.extend_from_slice(&crc.to_le_bytes());
115
116 let filename = format!("snapshot-{version}.hlss");
118 let final_path = self.dir.join(&filename);
119 let tmp_path = self.dir.join(format!(".{filename}.tmp"));
120 std::fs::write(&tmp_path, &buf)?;
121 std::fs::rename(&tmp_path, &final_path)?;
122
123 self.rotate()?;
125
126 Ok(final_path)
127 }
128
129 pub fn rotate(&self) -> Result<usize, SnapshotError> {
131 let mut snapshots = Self::list_snapshots(&self.dir)?;
132 if snapshots.len() <= self.max_snapshots {
133 return Ok(0);
134 }
135 snapshots.sort();
137 let to_remove = snapshots.len() - self.max_snapshots;
138 for path in &snapshots[..to_remove] {
139 std::fs::remove_file(path)?;
140 }
141 Ok(to_remove)
142 }
143
144 fn list_snapshots(dir: &Path) -> Result<Vec<PathBuf>, SnapshotError> {
145 if !dir.exists() {
146 return Ok(Vec::new());
147 }
148 Ok(std::fs::read_dir(dir)?
149 .filter_map(|e| e.ok())
150 .filter(|e| e.path().extension().is_some_and(|ext| ext == "hlss"))
151 .filter(|e| !e.file_name().to_str().unwrap_or("").starts_with('.'))
152 .map(|e| e.path())
153 .collect())
154 }
155}
156
157pub struct SnapshotReader;
159
160impl SnapshotReader {
161 pub fn find_latest(dir: &Path) -> Result<Option<PathBuf>, SnapshotError> {
163 let mut snapshots = SnapshotWriter::list_snapshots(dir)?;
164 if snapshots.is_empty() {
165 return Ok(None);
166 }
167 snapshots.sort();
168 Ok(Some(snapshots.pop().unwrap()))
169 }
170
171 pub fn load(path: &Path, graph: &SharedGraph) -> Result<SnapshotMeta, SnapshotError> {
173 let data = std::fs::read(path)?;
174 Self::load_from_bytes(&data, path, graph)
175 }
176
177 pub fn load_from_bytes(
179 data: &[u8],
180 path: &Path,
181 graph: &SharedGraph,
182 ) -> Result<SnapshotMeta, SnapshotError> {
183 if data.len() < 30 {
185 return Err(SnapshotError::InvalidMagic);
186 }
187
188 if &data[0..4] != b"HLSS" {
190 return Err(SnapshotError::InvalidMagic);
191 }
192
193 let format_version = u16::from_le_bytes([data[4], data[5]]);
195 if format_version != 1 {
196 return Err(SnapshotError::UnsupportedVersion(format_version));
197 }
198 let entity_count = u32::from_le_bytes(data[6..10].try_into().unwrap());
199 let timestamp = i64::from_le_bytes(data[10..18].try_into().unwrap());
200 let graph_version = u64::from_le_bytes(data[18..26].try_into().unwrap());
201
202 let crc_offset = data.len() - 4;
204 let expected_crc = u32::from_le_bytes(data[crc_offset..].try_into().unwrap());
205 let actual_crc = crc32fast::hash(&data[..crc_offset]);
206 if expected_crc != actual_crc {
207 return Err(SnapshotError::CrcMismatch {
208 expected: expected_crc,
209 actual: actual_crc,
210 });
211 }
212
213 let compressed = &data[26..crc_offset];
215 const MAX_DECOMPRESSED_SIZE: usize = 1024 * 1024 * 1024; let decompressed = zstd::decode_all(compressed)
217 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
218 if decompressed.len() > MAX_DECOMPRESSED_SIZE {
219 return Err(SnapshotError::Decompression(format!(
220 "decompressed data too large: {} bytes (max {})",
221 decompressed.len(),
222 MAX_DECOMPRESSED_SIZE
223 )));
224 }
225
226 let zinc_str =
228 std::str::from_utf8(&decompressed).map_err(|e| SnapshotError::Codec(e.to_string()))?;
229 let zinc = ZincCodec;
230 let grid = zinc
231 .decode_grid(zinc_str)
232 .map_err(|e| SnapshotError::Codec(e.to_string()))?;
233
234 graph.write(|g| {
236 for row in &grid.rows {
237 if let Some(id_ref) = row.id() {
238 g.add(row.clone()).ok(); let _ = id_ref; }
241 }
242 });
243
244 Ok(SnapshotMeta {
245 format_version,
246 entity_count,
247 timestamp,
248 graph_version,
249 path: path.to_path_buf(),
250 })
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EntityGraph;
    use crate::kinds::{HRef, Kind};

    /// Builds a minimal site entity with the given id.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    /// A fresh shared graph holding three sites.
    fn populated_graph() -> SharedGraph {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();
        sg.add(make_site("site-3")).unwrap();
        sg
    }

    /// File names (not full paths) currently listed in `dir`, sorted.
    fn listed_names(dir: &Path) -> Vec<String> {
        let mut names: Vec<String> = SnapshotWriter::list_snapshots(dir)
            .unwrap()
            .iter()
            .map(|p| p.file_name().unwrap().to_string_lossy().into_owned())
            .collect();
        names.sort();
        names
    }

    #[test]
    fn write_and_read_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let graph = populated_graph();

        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 5);
        let snap_path = writer.write(&graph).unwrap();
        assert!(snap_path.exists());

        let graph2 = SharedGraph::new(EntityGraph::new());
        let meta = SnapshotReader::load(&snap_path, &graph2).unwrap();

        assert_eq!(meta.format_version, 1);
        assert_eq!(meta.entity_count, 3);
        assert_eq!(meta.graph_version, 3);
        assert_eq!(graph2.len(), 3);
        assert!(graph2.contains("site-1"));
        assert!(graph2.contains("site-2"));
        assert!(graph2.contains("site-3"));
    }

    #[test]
    fn find_latest_returns_most_recent() {
        let dir = tempfile::tempdir().unwrap();
        let graph = SharedGraph::new(EntityGraph::new());
        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 5);

        graph.add(make_site("site-1")).unwrap();
        let _path1 = writer.write(&graph).unwrap();

        graph.add(make_site("site-2")).unwrap();
        let path2 = writer.write(&graph).unwrap();

        let latest = SnapshotReader::find_latest(dir.path()).unwrap().unwrap();
        assert_eq!(latest, path2);
    }

    #[test]
    fn rotate_removes_old_snapshots() {
        let dir = tempfile::tempdir().unwrap();
        let graph = SharedGraph::new(EntityGraph::new());
        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 2);

        // Graph versions 1..=4, one snapshot per version.
        for i in 0..4 {
            graph.add(make_site(&format!("s-{i}"))).unwrap();
            writer.write(&graph).unwrap();
        }

        // Only the two newest snapshots may survive.
        assert_eq!(
            listed_names(dir.path()),
            ["snapshot-3.hlss", "snapshot-4.hlss"]
        );
    }

    #[test]
    fn rotate_is_noop_within_limit() {
        let dir = tempfile::tempdir().unwrap();
        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 5);
        writer.write(&populated_graph()).unwrap();

        assert_eq!(writer.rotate().unwrap(), 0);
        assert_eq!(listed_names(dir.path()).len(), 1);
    }

    #[test]
    fn list_snapshots_ignores_hidden_and_foreign_files() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("snapshot-1.hlss"), b"x").unwrap();
        std::fs::write(dir.path().join(".snapshot-2.hlss.tmp"), b"x").unwrap();
        std::fs::write(dir.path().join(".hidden.hlss"), b"x").unwrap();
        std::fs::write(dir.path().join("notes.txt"), b"x").unwrap();

        let listed = SnapshotWriter::list_snapshots(dir.path()).unwrap();
        assert_eq!(listed, vec![dir.path().join("snapshot-1.hlss")]);
    }

    #[test]
    fn corrupt_crc_detected() {
        let dir = tempfile::tempdir().unwrap();
        let graph = populated_graph();

        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 5);
        let snap_path = writer.write(&graph).unwrap();

        let mut data = std::fs::read(&snap_path).unwrap();
        assert!(data.len() > 30, "snapshot unexpectedly small: {} bytes", data.len());
        // Flip a byte in the compressed payload (the header is 26 bytes).
        data[28] ^= 0xFF;

        let graph2 = SharedGraph::new(EntityGraph::new());
        let result = SnapshotReader::load_from_bytes(&data, &snap_path, &graph2);
        assert!(matches!(result, Err(SnapshotError::CrcMismatch { .. })));
    }

    #[test]
    fn invalid_magic_rejected() {
        let data = b"NOPE_this_is_not_a_snapshot_at_all";
        let graph = SharedGraph::new(EntityGraph::new());
        let result = SnapshotReader::load_from_bytes(data, Path::new("bad.hlss"), &graph);
        assert!(matches!(result, Err(SnapshotError::InvalidMagic)));
    }

    #[test]
    fn empty_graph_produces_valid_snapshot() {
        let dir = tempfile::tempdir().unwrap();
        let graph = SharedGraph::new(EntityGraph::new());

        let writer = SnapshotWriter::new(dir.path().to_path_buf(), 5);
        let snap_path = writer.write(&graph).unwrap();

        let graph2 = SharedGraph::new(EntityGraph::new());
        let meta = SnapshotReader::load(&snap_path, &graph2).unwrap();
        assert_eq!(meta.entity_count, 0);
        assert_eq!(graph2.len(), 0);
    }

    #[test]
    fn find_latest_on_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        let result = SnapshotReader::find_latest(dir.path()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn find_latest_on_missing_dir() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does-not-exist");
        let result = SnapshotReader::find_latest(&missing).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn too_short_data_rejected() {
        let data = b"HLSS_short";
        let graph = SharedGraph::new(EntityGraph::new());
        let result = SnapshotReader::load_from_bytes(data, Path::new("x.hlss"), &graph);
        assert!(matches!(result, Err(SnapshotError::InvalidMagic)));
    }

    #[test]
    fn unsupported_version_rejected() {
        // Hand-built header with version 99, a dummy payload and a valid CRC,
        // so only the version check can reject it.
        let mut data = Vec::new();
        data.extend_from_slice(b"HLSS");
        data.extend_from_slice(&99u16.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend_from_slice(&0i64.to_le_bytes());
        data.extend_from_slice(&0u64.to_le_bytes());
        data.extend_from_slice(&[0u8; 4]);
        let crc = crc32fast::hash(&data);
        data.extend_from_slice(&crc.to_le_bytes());

        let graph = SharedGraph::new(EntityGraph::new());
        let result = SnapshotReader::load_from_bytes(&data, Path::new("x.hlss"), &graph);
        assert!(matches!(result, Err(SnapshotError::UnsupportedVersion(99))));
    }
}