1use crate::error::{PackError, Result};
2use crate::format::{PackedSnapshot, SnapshotHeader, PackFormat};
3use crate::compression::{CompressionCodec, compress, decompress};
4use crate::metadata::SnapshotMetadata;
5use std::path::{Path, PathBuf};
6use std::fs::File;
7use std::io::{Write, Read};
8use sha2::{Sha256, Digest};
9
10#[cfg(feature = "encryption")]
11use crate::encryption::{EncryptionKey, encrypt_snapshot, decrypt_snapshot};
12
13pub struct SnapshotWriter {
14 compression: CompressionCodec,
15 #[cfg(feature = "encryption")]
16 encryption_key: Option<EncryptionKey>,
17}
18
19impl SnapshotWriter {
20 pub fn new() -> Self {
21 Self {
22 compression: CompressionCodec::zstd_default(),
23 #[cfg(feature = "encryption")]
24 encryption_key: None,
25 }
26 }
27
28 pub fn with_compression(mut self, codec: CompressionCodec) -> Self {
29 self.compression = codec;
30 self
31 }
32
33 #[cfg(feature = "encryption")]
34 pub fn with_encryption(mut self, key: EncryptionKey) -> Self {
35 self.encryption_key = Some(key);
36 self
37 }
38
39 pub fn write_to_file<P: AsRef<Path>>(
40 &self,
41 snapshot: &PackedSnapshot,
42 path: P,
43 ) -> Result<()> {
44 let serialized = self.serialize_snapshot(snapshot)?;
45
46 let compressed = compress(&serialized, self.compression)?;
47
48 #[cfg(feature = "encryption")]
49 let final_data = if let Some(key) = &self.encryption_key {
50 encrypt_snapshot(&compressed, key)?
51 } else {
52 compressed
53 };
54
55 #[cfg(not(feature = "encryption"))]
56 let final_data = compressed;
57
58 let mut header = snapshot.header.clone();
59 header.compression = self.compression.into();
60
61 #[cfg(feature = "encryption")]
62 {
63 header.encrypted = self.encryption_key.is_some();
64 }
65
66 header.checksum = self.compute_checksum(&final_data);
67 header.data_size = final_data.len() as u64;
68
69 let header_bytes = bincode::serialize(&header)?;
70 header.data_offset = header_bytes.len() as u64;
71
72 let final_header_bytes = bincode::serialize(&header)?;
73
74 let mut file = File::create(path)?;
75
76 file.write_all(&final_header_bytes)?;
77
78 file.write_all(&final_data)?;
79
80 file.sync_all()?;
81
82 Ok(())
83 }
84
85 pub fn write_to_bytes(&self, snapshot: &PackedSnapshot) -> Result<Vec<u8>> {
86 let serialized = self.serialize_snapshot(snapshot)?;
87
88 let compressed = compress(&serialized, self.compression)?;
89
90 #[cfg(feature = "encryption")]
91 let final_data = if let Some(key) = &self.encryption_key {
92 encrypt_snapshot(&compressed, key)?
93 } else {
94 compressed
95 };
96
97 #[cfg(not(feature = "encryption"))]
98 let final_data = compressed;
99
100 let mut header = snapshot.header.clone();
101 header.compression = self.compression.into();
102
103 #[cfg(feature = "encryption")]
104 {
105 header.encrypted = self.encryption_key.is_some();
106 }
107
108 header.checksum = self.compute_checksum(&final_data);
109 header.data_size = final_data.len() as u64;
110
111 let header_bytes = bincode::serialize(&header)?;
112 header.data_offset = header_bytes.len() as u64;
113
114 let final_header_bytes = bincode::serialize(&header)?;
115
116 let mut result = Vec::with_capacity(final_header_bytes.len() + final_data.len());
117 result.extend_from_slice(&final_header_bytes);
118 result.extend_from_slice(&final_data);
119
120 Ok(result)
121 }
122
123 fn serialize_snapshot(&self, snapshot: &PackedSnapshot) -> Result<Vec<u8>> {
124 match snapshot.header.format {
125 PackFormat::Bincode => {
126 bincode::serialize(snapshot)
127 .map_err(|e| PackError::Serialization(e.to_string()))
128 }
129 PackFormat::MessagePack => {
130 rmp_serde::to_vec(snapshot)
131 .map_err(|e| PackError::Serialization(e.to_string()))
132 }
133 PackFormat::Custom => {
134 Err(PackError::Serialization("Custom format not implemented".to_string()))
135 }
136 }
137 }
138
139 fn compute_checksum(&self, data: &[u8]) -> [u8; 32] {
140 let mut hasher = Sha256::new();
141 hasher.update(data);
142 hasher.finalize().into()
143 }
144}
145
146impl Default for SnapshotWriter {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
152pub struct SnapshotReader {
153 #[cfg(feature = "encryption")]
154 encryption_key: Option<EncryptionKey>,
155}
156
157impl SnapshotReader {
158 pub fn new() -> Self {
159 Self {
160 #[cfg(feature = "encryption")]
161 encryption_key: None,
162 }
163 }
164
165 #[cfg(feature = "encryption")]
166 pub fn with_encryption(mut self, key: EncryptionKey) -> Self {
167 self.encryption_key = Some(key);
168 self
169 }
170
171 pub fn read_from_file<P: AsRef<Path>>(&self, path: P) -> Result<PackedSnapshot> {
172 let mut file = File::open(path)?;
173
174 let mut all_data = Vec::new();
175 file.read_to_end(&mut all_data)?;
176
177 let header: SnapshotHeader = bincode::deserialize(&all_data)?;
178 header.validate()?;
179
180 let data_start = header.data_offset as usize;
181 let data_end = data_start + header.data_size as usize;
182
183 if data_end > all_data.len() {
184 return Err(PackError::InvalidFormat(
185 format!("Data end {} exceeds file length {}", data_end, all_data.len())
186 ));
187 }
188
189 let data = &all_data[data_start..data_end];
190
191 self.verify_checksum(data, &header.checksum)?;
192
193 let decompressed = if header.encrypted {
194 #[cfg(feature = "encryption")]
195 {
196 let key = self.encryption_key.as_ref()
197 .ok_or_else(|| PackError::Decryption("No encryption key provided".to_string()))?;
198 let decrypted = decrypt_snapshot(data, key)?;
199 decompress(&decrypted, header.compression)?
200 }
201
202 #[cfg(not(feature = "encryption"))]
203 {
204 return Err(PackError::Decryption("Snapshot is encrypted but encryption feature is disabled".to_string()));
205 }
206 } else {
207 decompress(data, header.compression)?
208 };
209
210 self.deserialize_snapshot(&decompressed, header.format)
211 }
212
213 pub fn read_from_bytes(&self, bytes: &[u8]) -> Result<PackedSnapshot> {
214 let header: SnapshotHeader = bincode::deserialize(bytes)?;
215 header.validate()?;
216
217 let data_start = header.data_offset as usize;
218 let data_end = data_start + header.data_size as usize;
219
220 if data_end > bytes.len() {
221 return Err(PackError::InvalidFormat(
222 format!("Data end {} exceeds buffer length {}", data_end, bytes.len())
223 ));
224 }
225
226 let data = &bytes[data_start..data_end];
227
228 self.verify_checksum(data, &header.checksum)?;
229
230 let decompressed = if header.encrypted {
231 #[cfg(feature = "encryption")]
232 {
233 let key = self.encryption_key.as_ref()
234 .ok_or_else(|| PackError::Decryption("No encryption key provided".to_string()))?;
235 let decrypted = decrypt_snapshot(data, key)?;
236 decompress(&decrypted, header.compression)?
237 }
238
239 #[cfg(not(feature = "encryption"))]
240 {
241 return Err(PackError::Decryption("Snapshot is encrypted but encryption feature is disabled".to_string()));
242 }
243 } else {
244 decompress(data, header.compression)?
245 };
246
247 self.deserialize_snapshot(&decompressed, header.format)
248 }
249
250 fn deserialize_snapshot(&self, data: &[u8], format: PackFormat) -> Result<PackedSnapshot> {
251 match format {
252 PackFormat::Bincode => {
253 bincode::deserialize(data)
254 .map_err(|e| PackError::Deserialization(e.to_string()))
255 }
256 PackFormat::MessagePack => {
257 rmp_serde::from_slice(data)
258 .map_err(|e| PackError::Deserialization(e.to_string()))
259 }
260 PackFormat::Custom => {
261 Err(PackError::Deserialization("Custom format not implemented".to_string()))
262 }
263 }
264 }
265
266 fn verify_checksum(&self, data: &[u8], expected: &[u8; 32]) -> Result<()> {
267 let mut hasher = Sha256::new();
268 hasher.update(data);
269 let actual: [u8; 32] = hasher.finalize().into();
270
271 if &actual != expected {
272 return Err(PackError::ChecksumMismatch);
273 }
274
275 Ok(())
276 }
277}
278
279impl Default for SnapshotReader {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285pub struct SnapshotStore {
286 root_dir: PathBuf,
287}
288
289impl SnapshotStore {
290 pub fn new<P: AsRef<Path>>(root_dir: P) -> Result<Self> {
291 let root_dir = root_dir.as_ref().to_path_buf();
292 std::fs::create_dir_all(&root_dir)?;
293
294 Ok(Self { root_dir })
295 }
296
297 pub fn save(
298 &self,
299 snapshot: &PackedSnapshot,
300 metadata: &SnapshotMetadata,
301 writer: &SnapshotWriter,
302 ) -> Result<PathBuf> {
303 let filename = format!("{}.tx2pack", metadata.id);
304 let path = self.root_dir.join(&filename);
305
306 writer.write_to_file(snapshot, &path)?;
307
308 let metadata_path = self.root_dir.join(format!("{}.meta.json", metadata.id));
309 let metadata_json = serde_json::to_string_pretty(metadata)?;
310 std::fs::write(metadata_path, metadata_json)?;
311
312 Ok(path)
313 }
314
315 pub fn load(&self, id: &str, reader: &SnapshotReader) -> Result<(PackedSnapshot, SnapshotMetadata)> {
316 let filename = format!("{}.tx2pack", id);
317 let path = self.root_dir.join(&filename);
318
319 if !path.exists() {
320 return Err(PackError::SnapshotNotFound(id.to_string()));
321 }
322
323 let snapshot = reader.read_from_file(&path)?;
324
325 let metadata_path = self.root_dir.join(format!("{}.meta.json", id));
326 let metadata = if metadata_path.exists() {
327 let metadata_json = std::fs::read_to_string(metadata_path)?;
328 serde_json::from_str(&metadata_json)?
329 } else {
330 SnapshotMetadata::new(id.to_string())
331 };
332
333 Ok((snapshot, metadata))
334 }
335
336 pub fn delete(&self, id: &str) -> Result<()> {
337 let filename = format!("{}.tx2pack", id);
338 let path = self.root_dir.join(&filename);
339
340 if path.exists() {
341 std::fs::remove_file(path)?;
342 }
343
344 let metadata_path = self.root_dir.join(format!("{}.meta.json", id));
345 if metadata_path.exists() {
346 std::fs::remove_file(metadata_path)?;
347 }
348
349 Ok(())
350 }
351
352 pub fn list(&self) -> Result<Vec<String>> {
353 let mut snapshots = Vec::new();
354
355 for entry in std::fs::read_dir(&self.root_dir)? {
356 let entry = entry?;
357 let path = entry.path();
358
359 if let Some(ext) = path.extension() {
360 if ext == "tx2pack" {
361 if let Some(stem) = path.file_stem() {
362 snapshots.push(stem.to_string_lossy().to_string());
363 }
364 }
365 }
366 }
367
368 Ok(snapshots)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::PackedSnapshot;
    use tempfile::TempDir;

    /// Round-trips a default snapshot through the in-memory writer and reader.
    #[test]
    fn test_write_read_snapshot() {
        let snapshot = PackedSnapshot::new();

        let writer = SnapshotWriter::new();
        let bytes = writer.write_to_bytes(&snapshot).unwrap();

        let reader = SnapshotReader::new();
        let loaded = reader.read_from_bytes(&bytes).unwrap();

        assert_eq!(snapshot.header.version, loaded.header.version);
    }

    /// Exercises save, list, load, and delete against a temporary directory.
    #[test]
    fn test_snapshot_store() {
        let temp_dir = TempDir::new().unwrap();
        let store = SnapshotStore::new(temp_dir.path()).unwrap();

        let snapshot = PackedSnapshot::new();
        let metadata = SnapshotMetadata::new("test-snapshot".to_string());

        let writer = SnapshotWriter::new();
        store.save(&snapshot, &metadata, &writer).unwrap();

        let snapshots = store.list().unwrap();
        assert!(snapshots.contains(&"test-snapshot".to_string()));

        let reader = SnapshotReader::new();
        let (loaded, loaded_meta) = store.load("test-snapshot", &reader).unwrap();

        assert_eq!(snapshot.header.version, loaded.header.version);
        assert_eq!(metadata.id, loaded_meta.id);

        store.delete("test-snapshot").unwrap();
        let snapshots = store.list().unwrap();
        assert!(!snapshots.contains(&"test-snapshot".to_string()));
    }

    /// Loading an unknown id reports `SnapshotNotFound`; deleting it is a no-op.
    #[test]
    fn test_missing_snapshot() {
        let temp_dir = TempDir::new().unwrap();
        let store = SnapshotStore::new(temp_dir.path()).unwrap();

        let result = store.load("missing", &SnapshotReader::new());
        assert!(matches!(result, Err(PackError::SnapshotNotFound(_))));

        store.delete("missing").unwrap();
    }

    /// Flipping a payload byte must be caught by the checksum.
    #[test]
    fn test_corrupted_payload_is_rejected() {
        let snapshot = PackedSnapshot::new();
        let mut bytes = SnapshotWriter::new().write_to_bytes(&snapshot).unwrap();

        // The payload is written after the header, so the final byte belongs to it.
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;

        let result = SnapshotReader::new().read_from_bytes(&bytes);
        assert!(matches!(result, Err(PackError::ChecksumMismatch)));
    }

    /// Truncated or empty input is reported as an error rather than panicking.
    #[test]
    fn test_truncated_input_is_rejected() {
        let snapshot = PackedSnapshot::new();
        let bytes = SnapshotWriter::new().write_to_bytes(&snapshot).unwrap();
        let reader = SnapshotReader::new();

        assert!(reader.read_from_bytes(&bytes[..bytes.len() - 1]).is_err());
        assert!(reader.read_from_bytes(&[]).is_err());
    }

    /// Encrypted snapshots round-trip with the key and are rejected without it.
    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_snapshot() {
        use crate::encryption::EncryptionKey;

        let snapshot = PackedSnapshot::new();
        let key = EncryptionKey::generate();

        let writer = SnapshotWriter::new().with_encryption(key.clone());
        let bytes = writer.write_to_bytes(&snapshot).unwrap();

        assert!(SnapshotReader::new().read_from_bytes(&bytes).is_err());

        let reader = SnapshotReader::new().with_encryption(key);
        let loaded = reader.read_from_bytes(&bytes).unwrap();

        assert_eq!(snapshot.header.version, loaded.header.version);
    }
}