Skip to main content

velesdb_core/agent/
snapshot.rs

1//! Snapshot and versioning support for `AgentMemory`.
2//!
3//! Provides serialization/deserialization of `AgentMemory` state for:
4//! - Persistence across restarts
5//! - Rollback to previous versions
6//! - State transfer between instances
7//!
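//! # Snapshot Format
//!
//! Every length field is an unsigned 64-bit little-endian integer giving the
//! size of the section that immediately follows it. The trailing CRC32 is
//! stored little-endian and covers every byte that precedes it (magic through
//! the end of the TTL state).
//!
//! ```text
//! [Magic: "VAMM" 4 bytes]
//! [Version: 1 byte]
//! [Semantic state length: 8 bytes]
//! [Semantic state: N bytes]
//! [Episodic state length: 8 bytes]
//! [Episodic state: N bytes]
//! [Procedural state length: 8 bytes]
//! [Procedural state: N bytes]
//! [TTL state length: 8 bytes]
//! [TTL state: N bytes]
//! [CRC32: 4 bytes]
//! ```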
23
24// SAFETY: Numeric casts in snapshot handling are intentional:
25// - usize to u32 in CRC32: i ranges 0-255, always fits in u32
26// - u64 to usize for lengths: Snapshot data is created/loaded on same architecture
27//   or architecture-compatible data. Lengths are validated before use.
28// All length values are bounds-checked against data.len() before array access.
29#![allow(clippy::cast_possible_truncation)]
30
31use std::fs::File;
32use std::io::{self, Read, Write};
33use std::path::Path;
34
35/// Snapshot file magic bytes for `AgentMemory`.
36pub const SNAPSHOT_MAGIC: &[u8; 4] = b"VAMM";
37
38/// Current snapshot format version.
39pub const SNAPSHOT_VERSION: u8 = 1;
40
/// Computes the CRC32 checksum of `data` using the reflected IEEE 802.3
/// polynomial (`0xEDB8_8320`).
///
/// The 256-entry lookup table is built at compile time, so each call only
/// performs one table lookup per input byte.
#[inline]
fn crc32_hash(data: &[u8]) -> u32 {
    /// Reflected form of the IEEE 802.3 generator polynomial.
    const POLY: u32 = 0xEDB8_8320;

    /// Builds the byte-indexed remainder table.
    const fn build_table() -> [u32; 256] {
        let mut table = [0u32; 256];
        let mut byte = 0usize;
        while byte < 256 {
            // `byte` is always in 0..=255, so the cast cannot truncate.
            let mut remainder = byte as u32;
            let mut bit = 0;
            while bit < 8 {
                remainder = if remainder & 1 == 0 {
                    remainder >> 1
                } else {
                    (remainder >> 1) ^ POLY
                };
                bit += 1;
            }
            table[byte] = remainder;
            byte += 1;
        }
        table
    }

    const TABLE: [u32; 256] = build_table();

    // Start from all ones and invert at the end, as the standard CRC32 does.
    let folded = data.iter().fold(u32::MAX, |acc, &byte| {
        let slot = ((acc ^ u32::from(byte)) & 0xFF) as usize;
        (acc >> 8) ^ TABLE[slot]
    });
    !folded
}
71
/// Serialized contents of each memory subsystem, as carried by a snapshot.
///
/// Every field is an opaque byte buffer: this module stores and restores the
/// bytes verbatim and never interprets them. `Default` yields four empty
/// buffers. `PartialEq`/`Eq` compare the buffers byte-for-byte, which makes it
/// possible to check that a state survives a save/load round trip.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryState {
    /// Serialized semantic memory entries.
    pub semantic: Vec<u8>,
    /// Serialized episodic memory entries.
    pub episodic: Vec<u8>,
    /// Serialized procedural memory entries.
    pub procedural: Vec<u8>,
    /// Serialized TTL state.
    pub ttl: Vec<u8>,
}
84
/// Summary information describing a snapshot.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// type; it is presumably populated by callers elsewhere — confirm.
#[derive(Clone, Debug)]
pub struct SnapshotMetadata {
    /// Format version of the snapshot.
    pub version: u8,
    /// Size of the snapshot in bytes.
    pub total_size: usize,
    /// CRC32 checksum of the snapshot.
    pub checksum: u32,
}
95
/// Failures that can occur while creating, validating, or loading a snapshot.
#[derive(Debug)]
pub enum SnapshotError {
    /// An underlying read/write operation failed.
    Io(io::Error),
    /// The data does not start with the expected magic bytes.
    InvalidMagic,
    /// The snapshot declares a format version this code does not handle.
    UnsupportedVersion(u8),
    /// The stored CRC32 does not match the one computed over the data.
    ChecksumMismatch {
        /// CRC32 value read from the snapshot trailer.
        expected: u32,
        /// CRC32 value computed from the snapshot contents.
        actual: u32,
    },
    /// The data is truncated or otherwise structurally invalid.
    CorruptedData(String),
}

impl std::fmt::Display for SnapshotError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::InvalidMagic => f.write_str("Invalid snapshot magic bytes"),
            Self::UnsupportedVersion(found) => {
                write!(f, "Unsupported snapshot version: {found}")
            }
            Self::ChecksumMismatch { expected, actual } => write!(
                f,
                "Checksum mismatch: expected {expected:08x}, got {actual:08x}"
            ),
            Self::CorruptedData(detail) => write!(f, "Corrupted data: {detail}"),
        }
    }
}

impl std::error::Error for SnapshotError {
    /// Exposes the wrapped I/O error as the cause; other variants have none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::InvalidMagic
            | Self::UnsupportedVersion(_)
            | Self::ChecksumMismatch { .. }
            | Self::CorruptedData(_) => None,
        }
    }
}

impl From<io::Error> for SnapshotError {
    /// Wraps an I/O error so that `?` can be used on I/O results.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
147
148/// Creates a snapshot from memory state.
149///
150/// # Arguments
151///
152/// * `state` - Memory state to serialize
153///
154/// # Returns
155///
156/// Serialized snapshot bytes.
157#[must_use]
158pub fn create_snapshot(state: &MemoryState) -> Vec<u8> {
159    let total_size = 4
160        + 1
161        + 8
162        + state.semantic.len()
163        + 8
164        + state.episodic.len()
165        + 8
166        + state.procedural.len()
167        + 8
168        + state.ttl.len()
169        + 4;
170    let mut buf = Vec::with_capacity(total_size);
171
172    buf.extend_from_slice(SNAPSHOT_MAGIC);
173    buf.push(SNAPSHOT_VERSION);
174
175    buf.extend_from_slice(&(state.semantic.len() as u64).to_le_bytes());
176    buf.extend_from_slice(&state.semantic);
177
178    buf.extend_from_slice(&(state.episodic.len() as u64).to_le_bytes());
179    buf.extend_from_slice(&state.episodic);
180
181    buf.extend_from_slice(&(state.procedural.len() as u64).to_le_bytes());
182    buf.extend_from_slice(&state.procedural);
183
184    buf.extend_from_slice(&(state.ttl.len() as u64).to_le_bytes());
185    buf.extend_from_slice(&state.ttl);
186
187    let crc = crc32_hash(&buf);
188    buf.extend_from_slice(&crc.to_le_bytes());
189
190    buf
191}
192
193/// Loads a snapshot from bytes.
194///
195/// # Arguments
196///
197/// * `data` - Snapshot bytes
198///
199/// # Errors
200///
201/// Returns error if snapshot is invalid or corrupted.
202pub fn load_snapshot(data: &[u8]) -> Result<MemoryState, SnapshotError> {
203    validate_snapshot_header(data)?;
204
205    let mut offset = 5; // skip magic (4) + version (1)
206    let payload_end = data.len() - 4; // exclude trailing CRC
207
208    let semantic = read_section(data, &mut offset, payload_end, "Semantic")?;
209    let episodic = read_section(data, &mut offset, payload_end, "Episodic")?;
210    let procedural = read_section(data, &mut offset, payload_end, "Procedural")?;
211    let ttl = read_section(data, &mut offset, payload_end, "TTL")?;
212
213    Ok(MemoryState {
214        semantic,
215        episodic,
216        procedural,
217        ttl,
218    })
219}
220
221/// Validates magic bytes, version, and CRC32 checksum of a snapshot.
222fn validate_snapshot_header(data: &[u8]) -> Result<(), SnapshotError> {
223    const MIN_SIZE: usize = 4 + 1 + 8 + 8 + 8 + 8 + 4;
224
225    if data.len() < MIN_SIZE {
226        return Err(SnapshotError::CorruptedData(
227            "Snapshot too small".to_string(),
228        ));
229    }
230    if &data[0..4] != SNAPSHOT_MAGIC {
231        return Err(SnapshotError::InvalidMagic);
232    }
233    let version = data[4];
234    if version != SNAPSHOT_VERSION {
235        return Err(SnapshotError::UnsupportedVersion(version));
236    }
237
238    let stored_crc = u32::from_le_bytes(
239        data[data.len() - 4..]
240            .try_into()
241            .map_err(|_| SnapshotError::CorruptedData("Invalid CRC bytes".to_string()))?,
242    );
243    let computed_crc = crc32_hash(&data[..data.len() - 4]);
244    if stored_crc != computed_crc {
245        return Err(SnapshotError::ChecksumMismatch {
246            expected: stored_crc,
247            actual: computed_crc,
248        });
249    }
250    Ok(())
251}
252
253/// Reads a length-prefixed section from the snapshot data.
254fn read_section(
255    data: &[u8],
256    offset: &mut usize,
257    payload_end: usize,
258    label: &str,
259) -> Result<Vec<u8>, SnapshotError> {
260    let section_len = read_u64(&data[*offset..])? as usize;
261    *offset += 8;
262    if *offset + section_len > payload_end {
263        return Err(SnapshotError::CorruptedData(format!(
264            "{label} data truncated"
265        )));
266    }
267    let section = data[*offset..*offset + section_len].to_vec();
268    *offset += section_len;
269    Ok(section)
270}
271
272/// Saves a snapshot to a file.
273///
274/// Uses atomic write (temp file + rename) for safety.
275///
276/// # Errors
277///
278/// Returns error if file operations fail.
279pub fn save_snapshot_to_file<P: AsRef<Path>>(
280    path: P,
281    state: &MemoryState,
282) -> Result<(), SnapshotError> {
283    let path = path.as_ref();
284    let snapshot_data = create_snapshot(state);
285
286    let temp_path = path.with_extension("tmp");
287    let mut file = File::create(&temp_path)?;
288    file.write_all(&snapshot_data)?;
289    file.sync_all()?;
290    drop(file);
291
292    std::fs::rename(&temp_path, path)?;
293
294    Ok(())
295}
296
297/// Loads a snapshot from a file.
298///
299/// # Errors
300///
301/// Returns error if file operations fail or snapshot is invalid.
302pub fn load_snapshot_from_file<P: AsRef<Path>>(path: P) -> Result<MemoryState, SnapshotError> {
303    let mut file = File::open(path)?;
304    let mut data = Vec::new();
305    file.read_to_end(&mut data)?;
306    load_snapshot(&data)
307}
308
309/// Helper to read u64 from bytes.
310fn read_u64(data: &[u8]) -> Result<u64, SnapshotError> {
311    if data.len() < 8 {
312        return Err(SnapshotError::CorruptedData(
313            "Not enough bytes for u64".to_string(),
314        ));
315    }
316    Ok(u64::from_le_bytes(data[0..8].try_into().map_err(|_| {
317        SnapshotError::CorruptedData("Invalid u64 bytes".to_string())
318    })?))
319}
320
/// Manages a directory of numbered snapshot files with bounded retention.
///
/// The manager holds only configuration (a directory and a retention limit),
/// so it is cheap to clone and safe to print for diagnostics.
#[derive(Debug, Clone)]
pub struct SnapshotManager {
    /// Base directory for snapshots.
    base_path: std::path::PathBuf,
    /// Maximum number of snapshots to retain.
    max_snapshots: usize,
}
328
329impl SnapshotManager {
330    /// Creates a new snapshot manager.
331    ///
332    /// # Arguments
333    ///
334    /// * `base_path` - Directory for storing snapshots
335    /// * `max_snapshots` - Maximum number of snapshots to retain
336    pub fn new<P: AsRef<Path>>(base_path: P, max_snapshots: usize) -> Self {
337        Self {
338            base_path: base_path.as_ref().to_path_buf(),
339            max_snapshots,
340        }
341    }
342
343    /// Creates a new versioned snapshot.
344    ///
345    /// # Returns
346    ///
347    /// The version number of the created snapshot.
348    ///
349    /// # Errors
350    ///
351    /// Returns error if file operations fail.
352    pub fn create_versioned_snapshot(&self, state: &MemoryState) -> Result<u64, SnapshotError> {
353        std::fs::create_dir_all(&self.base_path)?;
354
355        let version = self.next_version()?;
356        let filename = format!("snapshot_{version:08}.vamm");
357        let path = self.base_path.join(filename);
358
359        save_snapshot_to_file(&path, state)?;
360        self.cleanup_old_snapshots()?;
361
362        Ok(version)
363    }
364
365    /// Loads the latest snapshot.
366    ///
367    /// # Errors
368    ///
369    /// Returns error if no snapshots exist or loading fails.
370    pub fn load_latest(&self) -> Result<(u64, MemoryState), SnapshotError> {
371        let version = self
372            .latest_version()?
373            .ok_or_else(|| SnapshotError::CorruptedData("No snapshots found".to_string()))?;
374        let state = self.load_version(version)?;
375        Ok((version, state))
376    }
377
378    /// Loads a specific snapshot version.
379    ///
380    /// # Errors
381    ///
382    /// Returns error if version doesn't exist or loading fails.
383    pub fn load_version(&self, version: u64) -> Result<MemoryState, SnapshotError> {
384        let filename = format!("snapshot_{version:08}.vamm");
385        let path = self.base_path.join(filename);
386        load_snapshot_from_file(&path)
387    }
388
389    /// Lists all available snapshot versions.
390    ///
391    /// # Errors
392    ///
393    /// Returns error if directory operations fail.
394    pub fn list_versions(&self) -> Result<Vec<u64>, SnapshotError> {
395        if !self.base_path.exists() {
396            return Ok(Vec::new());
397        }
398
399        let mut versions = Vec::new();
400        for entry in std::fs::read_dir(&self.base_path)? {
401            let entry = entry?;
402            let filename = entry.file_name();
403            let filename_str = filename.to_string_lossy();
404
405            if filename_str.starts_with("snapshot_") && filename_str.ends_with(".vamm") {
406                if let Some(version_str) = filename_str
407                    .strip_prefix("snapshot_")
408                    .and_then(|s| s.strip_suffix(".vamm"))
409                {
410                    if let Ok(version) = version_str.parse::<u64>() {
411                        versions.push(version);
412                    }
413                }
414            }
415        }
416
417        versions.sort_unstable();
418        Ok(versions)
419    }
420
421    /// Returns the latest snapshot version.
422    fn latest_version(&self) -> Result<Option<u64>, SnapshotError> {
423        Ok(self.list_versions()?.into_iter().max())
424    }
425
426    /// Returns the next version number.
427    fn next_version(&self) -> Result<u64, SnapshotError> {
428        Ok(self.latest_version()?.map_or(1, |v| v + 1))
429    }
430
431    /// Removes old snapshots beyond the retention limit.
432    fn cleanup_old_snapshots(&self) -> Result<(), SnapshotError> {
433        let versions = self.list_versions()?;
434        if versions.len() <= self.max_snapshots {
435            return Ok(());
436        }
437
438        let to_remove = versions.len() - self.max_snapshots;
439        for version in versions.into_iter().take(to_remove) {
440            let filename = format!("snapshot_{version:08}.vamm");
441            let path = self.base_path.join(filename);
442            let _ = std::fs::remove_file(path);
443        }
444
445        Ok(())
446    }
447}