tx2_pack/
storage.rs

1use crate::error::{PackError, Result};
2use crate::format::{PackedSnapshot, SnapshotHeader, PackFormat};
3use crate::compression::{CompressionCodec, compress, decompress};
4use crate::metadata::SnapshotMetadata;
5use std::path::{Path, PathBuf};
6use std::fs::File;
7use std::io::{Write, Read};
8use sha2::{Sha256, Digest};
9
10#[cfg(feature = "encryption")]
11use crate::encryption::{EncryptionKey, encrypt_snapshot, decrypt_snapshot};
12
/// Serializes, compresses and (optionally) encrypts a [`PackedSnapshot`],
/// then emits it as a bincode-encoded header followed by the payload bytes.
///
/// Configure it builder-style via `with_compression` and, when the
/// `encryption` feature is enabled, `with_encryption`.
pub struct SnapshotWriter {
    /// Codec applied to the serialized snapshot before optional encryption.
    compression: CompressionCodec,
    /// Key used to encrypt the compressed payload; `None` writes it unencrypted.
    #[cfg(feature = "encryption")]
    encryption_key: Option<EncryptionKey>,
}
18
19impl SnapshotWriter {
20    pub fn new() -> Self {
21        Self {
22            compression: CompressionCodec::zstd_default(),
23            #[cfg(feature = "encryption")]
24            encryption_key: None,
25        }
26    }
27
28    pub fn with_compression(mut self, codec: CompressionCodec) -> Self {
29        self.compression = codec;
30        self
31    }
32
33    #[cfg(feature = "encryption")]
34    pub fn with_encryption(mut self, key: EncryptionKey) -> Self {
35        self.encryption_key = Some(key);
36        self
37    }
38
39    pub fn write_to_file<P: AsRef<Path>>(
40        &self,
41        snapshot: &PackedSnapshot,
42        path: P,
43    ) -> Result<()> {
44        let serialized = self.serialize_snapshot(snapshot)?;
45
46        let compressed = compress(&serialized, self.compression)?;
47
48        #[cfg(feature = "encryption")]
49        let final_data = if let Some(key) = &self.encryption_key {
50            encrypt_snapshot(&compressed, key)?
51        } else {
52            compressed
53        };
54
55        #[cfg(not(feature = "encryption"))]
56        let final_data = compressed;
57
58        let mut header = snapshot.header.clone();
59        header.compression = self.compression.into();
60
61        #[cfg(feature = "encryption")]
62        {
63            header.encrypted = self.encryption_key.is_some();
64        }
65
66        header.checksum = self.compute_checksum(&final_data);
67        header.data_size = final_data.len() as u64;
68
69        let header_bytes = bincode::serialize(&header)?;
70        header.data_offset = header_bytes.len() as u64;
71
72        let final_header_bytes = bincode::serialize(&header)?;
73
74        let mut file = File::create(path)?;
75
76        file.write_all(&final_header_bytes)?;
77
78        file.write_all(&final_data)?;
79
80        file.sync_all()?;
81
82        Ok(())
83    }
84
85    pub fn write_to_bytes(&self, snapshot: &PackedSnapshot) -> Result<Vec<u8>> {
86        let serialized = self.serialize_snapshot(snapshot)?;
87
88        let compressed = compress(&serialized, self.compression)?;
89
90        #[cfg(feature = "encryption")]
91        let final_data = if let Some(key) = &self.encryption_key {
92            encrypt_snapshot(&compressed, key)?
93        } else {
94            compressed
95        };
96
97        #[cfg(not(feature = "encryption"))]
98        let final_data = compressed;
99
100        let mut header = snapshot.header.clone();
101        header.compression = self.compression.into();
102
103        #[cfg(feature = "encryption")]
104        {
105            header.encrypted = self.encryption_key.is_some();
106        }
107
108        header.checksum = self.compute_checksum(&final_data);
109        header.data_size = final_data.len() as u64;
110
111        let header_bytes = bincode::serialize(&header)?;
112        header.data_offset = header_bytes.len() as u64;
113
114        let final_header_bytes = bincode::serialize(&header)?;
115
116        let mut result = Vec::with_capacity(final_header_bytes.len() + final_data.len());
117        result.extend_from_slice(&final_header_bytes);
118        result.extend_from_slice(&final_data);
119
120        Ok(result)
121    }
122
123    fn serialize_snapshot(&self, snapshot: &PackedSnapshot) -> Result<Vec<u8>> {
124        match snapshot.header.format {
125            PackFormat::Bincode => {
126                bincode::serialize(snapshot)
127                    .map_err(|e| PackError::Serialization(e.to_string()))
128            }
129            PackFormat::MessagePack => {
130                rmp_serde::to_vec(snapshot)
131                    .map_err(|e| PackError::Serialization(e.to_string()))
132            }
133            PackFormat::Custom => {
134                Err(PackError::Serialization("Custom format not implemented".to_string()))
135            }
136        }
137    }
138
139    fn compute_checksum(&self, data: &[u8]) -> [u8; 32] {
140        let mut hasher = Sha256::new();
141        hasher.update(data);
142        hasher.finalize().into()
143    }
144}
145
146impl Default for SnapshotWriter {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
/// Decodes snapshots from a file or byte buffer: parses the header, verifies
/// the payload checksum, decrypts if required, decompresses and deserializes.
pub struct SnapshotReader {
    /// Key used to decrypt payloads whose header marks them as encrypted.
    #[cfg(feature = "encryption")]
    encryption_key: Option<EncryptionKey>,
}
156
157impl SnapshotReader {
158    pub fn new() -> Self {
159        Self {
160            #[cfg(feature = "encryption")]
161            encryption_key: None,
162        }
163    }
164
165    #[cfg(feature = "encryption")]
166    pub fn with_encryption(mut self, key: EncryptionKey) -> Self {
167        self.encryption_key = Some(key);
168        self
169    }
170
171    pub fn read_from_file<P: AsRef<Path>>(&self, path: P) -> Result<PackedSnapshot> {
172        let mut file = File::open(path)?;
173
174        let mut all_data = Vec::new();
175        file.read_to_end(&mut all_data)?;
176
177        let header: SnapshotHeader = bincode::deserialize(&all_data)?;
178        header.validate()?;
179
180        let data_start = header.data_offset as usize;
181        let data_end = data_start + header.data_size as usize;
182
183        if data_end > all_data.len() {
184            return Err(PackError::InvalidFormat(
185                format!("Data end {} exceeds file length {}", data_end, all_data.len())
186            ));
187        }
188
189        let data = &all_data[data_start..data_end];
190
191        self.verify_checksum(data, &header.checksum)?;
192
193        let decompressed = if header.encrypted {
194            #[cfg(feature = "encryption")]
195            {
196                let key = self.encryption_key.as_ref()
197                    .ok_or_else(|| PackError::Decryption("No encryption key provided".to_string()))?;
198                let decrypted = decrypt_snapshot(data, key)?;
199                decompress(&decrypted, header.compression)?
200            }
201
202            #[cfg(not(feature = "encryption"))]
203            {
204                return Err(PackError::Decryption("Snapshot is encrypted but encryption feature is disabled".to_string()));
205            }
206        } else {
207            decompress(data, header.compression)?
208        };
209
210        self.deserialize_snapshot(&decompressed, header.format)
211    }
212
213    pub fn read_from_bytes(&self, bytes: &[u8]) -> Result<PackedSnapshot> {
214        let header: SnapshotHeader = bincode::deserialize(bytes)?;
215        header.validate()?;
216
217        let data_start = header.data_offset as usize;
218        let data_end = data_start + header.data_size as usize;
219
220        if data_end > bytes.len() {
221            return Err(PackError::InvalidFormat(
222                format!("Data end {} exceeds buffer length {}", data_end, bytes.len())
223            ));
224        }
225
226        let data = &bytes[data_start..data_end];
227
228        self.verify_checksum(data, &header.checksum)?;
229
230        let decompressed = if header.encrypted {
231            #[cfg(feature = "encryption")]
232            {
233                let key = self.encryption_key.as_ref()
234                    .ok_or_else(|| PackError::Decryption("No encryption key provided".to_string()))?;
235                let decrypted = decrypt_snapshot(data, key)?;
236                decompress(&decrypted, header.compression)?
237            }
238
239            #[cfg(not(feature = "encryption"))]
240            {
241                return Err(PackError::Decryption("Snapshot is encrypted but encryption feature is disabled".to_string()));
242            }
243        } else {
244            decompress(data, header.compression)?
245        };
246
247        self.deserialize_snapshot(&decompressed, header.format)
248    }
249
250    fn deserialize_snapshot(&self, data: &[u8], format: PackFormat) -> Result<PackedSnapshot> {
251        match format {
252            PackFormat::Bincode => {
253                bincode::deserialize(data)
254                    .map_err(|e| PackError::Deserialization(e.to_string()))
255            }
256            PackFormat::MessagePack => {
257                rmp_serde::from_slice(data)
258                    .map_err(|e| PackError::Deserialization(e.to_string()))
259            }
260            PackFormat::Custom => {
261                Err(PackError::Deserialization("Custom format not implemented".to_string()))
262            }
263        }
264    }
265
266    fn verify_checksum(&self, data: &[u8], expected: &[u8; 32]) -> Result<()> {
267        let mut hasher = Sha256::new();
268        hasher.update(data);
269        let actual: [u8; 32] = hasher.finalize().into();
270
271        if &actual != expected {
272            return Err(PackError::ChecksumMismatch);
273        }
274
275        Ok(())
276    }
277}
278
279impl Default for SnapshotReader {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
/// Directory-backed collection of snapshots, stored as `<id>.tx2pack` files
/// with a sibling `<id>.meta.json` metadata file.
pub struct SnapshotStore {
    /// Directory that holds every snapshot and metadata file of this store.
    root_dir: PathBuf,
}
288
289impl SnapshotStore {
290    pub fn new<P: AsRef<Path>>(root_dir: P) -> Result<Self> {
291        let root_dir = root_dir.as_ref().to_path_buf();
292        std::fs::create_dir_all(&root_dir)?;
293
294        Ok(Self { root_dir })
295    }
296
297    pub fn save(
298        &self,
299        snapshot: &PackedSnapshot,
300        metadata: &SnapshotMetadata,
301        writer: &SnapshotWriter,
302    ) -> Result<PathBuf> {
303        let filename = format!("{}.tx2pack", metadata.id);
304        let path = self.root_dir.join(&filename);
305
306        writer.write_to_file(snapshot, &path)?;
307
308        let metadata_path = self.root_dir.join(format!("{}.meta.json", metadata.id));
309        let metadata_json = serde_json::to_string_pretty(metadata)?;
310        std::fs::write(metadata_path, metadata_json)?;
311
312        Ok(path)
313    }
314
315    pub fn load(&self, id: &str, reader: &SnapshotReader) -> Result<(PackedSnapshot, SnapshotMetadata)> {
316        let filename = format!("{}.tx2pack", id);
317        let path = self.root_dir.join(&filename);
318
319        if !path.exists() {
320            return Err(PackError::SnapshotNotFound(id.to_string()));
321        }
322
323        let snapshot = reader.read_from_file(&path)?;
324
325        let metadata_path = self.root_dir.join(format!("{}.meta.json", id));
326        let metadata = if metadata_path.exists() {
327            let metadata_json = std::fs::read_to_string(metadata_path)?;
328            serde_json::from_str(&metadata_json)?
329        } else {
330            SnapshotMetadata::new(id.to_string())
331        };
332
333        Ok((snapshot, metadata))
334    }
335
336    pub fn delete(&self, id: &str) -> Result<()> {
337        let filename = format!("{}.tx2pack", id);
338        let path = self.root_dir.join(&filename);
339
340        if path.exists() {
341            std::fs::remove_file(path)?;
342        }
343
344        let metadata_path = self.root_dir.join(format!("{}.meta.json", id));
345        if metadata_path.exists() {
346            std::fs::remove_file(metadata_path)?;
347        }
348
349        Ok(())
350    }
351
352    pub fn list(&self) -> Result<Vec<String>> {
353        let mut snapshots = Vec::new();
354
355        for entry in std::fs::read_dir(&self.root_dir)? {
356            let entry = entry?;
357            let path = entry.path();
358
359            if let Some(ext) = path.extension() {
360                if ext == "tx2pack" {
361                    if let Some(stem) = path.file_stem() {
362                        snapshots.push(stem.to_string_lossy().to_string());
363                    }
364                }
365            }
366        }
367
368        Ok(snapshots)
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::PackedSnapshot;
    use tempfile::TempDir;

    /// A snapshot written to bytes can be read back with default settings.
    #[test]
    fn test_write_read_snapshot() {
        let original = PackedSnapshot::new();

        let encoded = SnapshotWriter::new().write_to_bytes(&original).unwrap();
        let decoded = SnapshotReader::new().read_from_bytes(&encoded).unwrap();

        assert_eq!(original.header.version, decoded.header.version);
    }

    /// Exercises the store's save, list, load and delete operations in order.
    #[test]
    fn test_snapshot_store() {
        const ID: &str = "test-snapshot";

        let scratch = TempDir::new().unwrap();
        let store = SnapshotStore::new(scratch.path()).unwrap();

        let original = PackedSnapshot::new();
        let original_meta = SnapshotMetadata::new(ID.to_string());

        store
            .save(&original, &original_meta, &SnapshotWriter::new())
            .unwrap();

        let listed = store.list().unwrap();
        assert!(listed.iter().any(|id| id == ID));

        let (decoded, decoded_meta) = store.load(ID, &SnapshotReader::new()).unwrap();

        assert_eq!(original.header.version, decoded.header.version);
        assert_eq!(original_meta.id, decoded_meta.id);

        store.delete(ID).unwrap();
        let listed = store.list().unwrap();
        assert!(!listed.iter().any(|id| id == ID));
    }

    /// An encrypted snapshot can be read back when the reader has the same key.
    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_snapshot() {
        use crate::encryption::EncryptionKey;

        let original = PackedSnapshot::new();
        let key = EncryptionKey::generate();

        let encoded = SnapshotWriter::new()
            .with_encryption(key.clone())
            .write_to_bytes(&original)
            .unwrap();

        let decoded = SnapshotReader::new()
            .with_encryption(key)
            .read_from_bytes(&encoded)
            .unwrap();

        assert_eq!(original.header.version, decoded.header.version);
    }
}