Skip to main content

zlayer_storage/
snapshot.rs

1//! Layer snapshot creation and extraction
2//!
3//! Handles tarball creation from `OverlayFS` upper layers with compression.
4
5use crate::error::{LayerStorageError, Result};
6use crate::types::LayerSnapshot;
7use sha2::{Digest, Sha256};
8use std::fs::File;
9use std::io::{BufReader, BufWriter, Read, Write};
10use std::path::Path;
11use tar::Builder;
12use tracing::{debug, info, instrument};
13
14/// Create a compressed tarball snapshot from a directory
15///
16/// Returns the snapshot metadata and path to the compressed tarball.
17///
18/// # Errors
19///
20/// Returns an error if the source directory cannot be read, the tarball cannot
21/// be created, or compression fails.
22#[instrument(skip(source_dir, output_path), fields(source = %source_dir.as_ref().display()))]
23pub fn create_snapshot(
24    source_dir: impl AsRef<Path>,
25    output_path: impl AsRef<Path>,
26    compression_level: i32,
27) -> Result<LayerSnapshot> {
28    let source_dir = source_dir.as_ref();
29    let output_path = output_path.as_ref();
30
31    info!("Creating layer snapshot from {}", source_dir.display());
32
33    // Create temporary uncompressed tarball first to calculate digest
34    let tar_temp_path = output_path.with_extension("tar.tmp");
35
36    // Build the tar archive
37    let tar_file = File::create(&tar_temp_path)?;
38    let mut tar_builder = Builder::new(BufWriter::new(tar_file));
39
40    let mut file_count = 0u64;
41    tar_builder.append_dir_all(".", source_dir)?;
42
43    // Finish writing the tar
44    tar_builder.into_inner()?.flush()?;
45
46    // Calculate SHA256 of uncompressed tar and count files
47    let mut hasher = Sha256::new();
48    let tar_file = File::open(&tar_temp_path)?;
49    let uncompressed_size = tar_file.metadata()?.len();
50    let mut reader = BufReader::new(tar_file);
51
52    // Count files while hashing
53    let mut buffer = [0u8; 8192];
54    loop {
55        let bytes_read = reader.read(&mut buffer)?;
56        if bytes_read == 0 {
57            break;
58        }
59        hasher.update(&buffer[..bytes_read]);
60    }
61
62    // Count entries in tar
63    let tar_file = File::open(&tar_temp_path)?;
64    let mut archive = tar::Archive::new(tar_file);
65    for entry in archive.entries()? {
66        let _ = entry?;
67        file_count += 1;
68    }
69
70    let digest = hex::encode(hasher.finalize());
71    debug!("Layer digest: {}", digest);
72
73    // Compress with zstd
74    let tar_file = File::open(&tar_temp_path)?;
75    let compressed_file = File::create(output_path)?;
76    let mut encoder =
77        zstd::stream::Encoder::new(BufWriter::new(compressed_file), compression_level)?;
78
79    let mut reader = BufReader::new(tar_file);
80    std::io::copy(&mut reader, &mut encoder)?;
81    encoder.finish()?.flush()?;
82
83    // Get compressed size
84    let compressed_size = std::fs::metadata(output_path)?.len();
85
86    // Clean up temp file
87    std::fs::remove_file(&tar_temp_path)?;
88
89    let snapshot = LayerSnapshot {
90        digest,
91        size_bytes: uncompressed_size,
92        compressed_size_bytes: compressed_size,
93        created_at: chrono::Utc::now(),
94        file_count,
95    };
96
97    #[allow(clippy::cast_precision_loss)]
98    let compression_pct = (1.0 - (compressed_size as f64 / uncompressed_size as f64)) * 100.0;
99    info!(
100        "Created snapshot: {} bytes -> {} bytes ({:.1}% compression), {} files",
101        uncompressed_size, compressed_size, compression_pct, file_count
102    );
103
104    Ok(snapshot)
105}
106
107/// Extract a compressed tarball snapshot to a directory
108///
109/// # Errors
110///
111/// Returns an error if decompression fails, the digest does not match, or
112/// extraction to the target directory fails.
113#[instrument(skip(tarball_path, target_dir), fields(tarball = %tarball_path.as_ref().display()))]
114pub fn extract_snapshot(
115    tarball_path: impl AsRef<Path>,
116    target_dir: impl AsRef<Path>,
117    expected_digest: Option<&str>,
118) -> Result<()> {
119    let tarball_path = tarball_path.as_ref();
120    let target_dir = target_dir.as_ref();
121
122    info!("Extracting layer snapshot to {}", target_dir.display());
123
124    // Decompress
125    let compressed_file = File::open(tarball_path)?;
126    let decoder = zstd::stream::Decoder::new(BufReader::new(compressed_file))?;
127
128    // If we need to verify digest, decompress to temp file first
129    if let Some(expected) = expected_digest {
130        let temp_tar = tarball_path.with_extension("tar.verify");
131        {
132            let mut temp_file = BufWriter::new(File::create(&temp_tar)?);
133            let mut decoder =
134                zstd::stream::Decoder::new(BufReader::new(File::open(tarball_path)?))?;
135            std::io::copy(&mut decoder, &mut temp_file)?;
136            temp_file.flush()?;
137        }
138
139        // Calculate digest
140        let mut hasher = Sha256::new();
141        let mut file = BufReader::new(File::open(&temp_tar)?);
142        let mut buffer = [0u8; 8192];
143        loop {
144            let bytes_read = file.read(&mut buffer)?;
145            if bytes_read == 0 {
146                break;
147            }
148            hasher.update(&buffer[..bytes_read]);
149        }
150
151        let actual_digest = hex::encode(hasher.finalize());
152        if actual_digest != expected {
153            std::fs::remove_file(&temp_tar)?;
154            return Err(LayerStorageError::ChecksumMismatch {
155                expected: expected.to_string(),
156                actual: actual_digest,
157            });
158        }
159
160        // Extract from verified temp file
161        let file = File::open(&temp_tar)?;
162        let mut archive = tar::Archive::new(file);
163        archive.unpack(target_dir)?;
164
165        std::fs::remove_file(&temp_tar)?;
166    } else {
167        // Extract directly without verification
168        let mut archive = tar::Archive::new(decoder);
169        archive.unpack(target_dir)?;
170    }
171
172    info!("Extraction complete");
173    Ok(())
174}
175
176/// Calculate the SHA256 digest of a directory's contents (for change detection)
177///
178/// # Errors
179///
180/// Returns an error if any file in the directory cannot be read.
181#[instrument(skip(dir), fields(dir = %dir.as_ref().display()))]
182pub fn calculate_directory_digest(dir: impl AsRef<Path>) -> Result<String> {
183    let dir = dir.as_ref();
184    let mut hasher = Sha256::new();
185
186    // Walk directory and hash file contents and metadata
187    fn hash_dir(hasher: &mut Sha256, dir: &Path, prefix: &Path) -> Result<()> {
188        let mut entries: Vec<_> = std::fs::read_dir(dir)?
189            .filter_map(std::result::Result::ok)
190            .collect();
191
192        // Sort for deterministic ordering
193        entries.sort_by_key(std::fs::DirEntry::file_name);
194
195        for entry in entries {
196            let path = entry.path();
197            let relative = path.strip_prefix(prefix).unwrap_or(&path);
198
199            // Hash the relative path
200            hasher.update(relative.to_string_lossy().as_bytes());
201
202            let metadata = entry.metadata()?;
203            if metadata.is_file() {
204                // Hash file size and contents
205                hasher.update(metadata.len().to_le_bytes());
206
207                let mut file = BufReader::new(File::open(&path)?);
208                let mut buffer = [0u8; 8192];
209                loop {
210                    let bytes_read = file.read(&mut buffer)?;
211                    if bytes_read == 0 {
212                        break;
213                    }
214                    hasher.update(&buffer[..bytes_read]);
215                }
216            } else if metadata.is_dir() {
217                hash_dir(hasher, &path, prefix)?;
218            }
219            // Skip symlinks and other special files for now
220        }
221
222        Ok(())
223    }
224
225    hash_dir(&mut hasher, dir, dir)?;
226    Ok(hex::encode(hasher.finalize()))
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Populates `dir` with one top-level file and one file inside a subdirectory.
    fn populate_source(dir: &Path) {
        std::fs::write(dir.join("test.txt"), "hello world").unwrap();
        std::fs::create_dir(dir.join("subdir")).unwrap();
        std::fs::write(dir.join("subdir/nested.txt"), "nested content").unwrap();
    }

    /// Asserts that the files written by `populate_source` exist under `dir` with the same contents.
    fn assert_extracted(dir: &Path) {
        assert_eq!(
            std::fs::read_to_string(dir.join("test.txt")).unwrap(),
            "hello world"
        );
        assert_eq!(
            std::fs::read_to_string(dir.join("subdir/nested.txt")).unwrap(),
            "nested content"
        );
    }

    #[test]
    fn test_snapshot_roundtrip() {
        let source = TempDir::new().unwrap();
        let staging = TempDir::new().unwrap();
        let target = TempDir::new().unwrap();

        // Create some test files
        populate_source(source.path());

        // Create snapshot
        let tarball_path = staging.path().join("layer.tar.zst");
        let snapshot = create_snapshot(source.path(), &tarball_path, 3).unwrap();

        assert!(!snapshot.digest.is_empty());
        assert!(snapshot.size_bytes > 0);
        assert!(snapshot.compressed_size_bytes > 0);
        assert!(snapshot.file_count >= 2);

        // The uncompressed staging tar must not be left behind.
        assert!(!tarball_path.with_extension("tar.tmp").exists());

        // Extract and verify
        extract_snapshot(&tarball_path, target.path(), Some(&snapshot.digest)).unwrap();

        assert_extracted(target.path());

        // The verification staging file must not be left behind either.
        assert!(!tarball_path.with_extension("tar.verify").exists());
    }

    #[test]
    fn test_extract_without_digest() {
        let source = TempDir::new().unwrap();
        let staging = TempDir::new().unwrap();
        let target = TempDir::new().unwrap();

        populate_source(source.path());

        let tarball_path = staging.path().join("layer.tar.zst");
        create_snapshot(source.path(), &tarball_path, 3).unwrap();

        // No expected digest: the archive is unpacked without verification.
        extract_snapshot(&tarball_path, target.path(), None).unwrap();

        assert_extracted(target.path());
    }

    #[test]
    fn test_extract_rejects_wrong_digest() {
        let source = TempDir::new().unwrap();
        let staging = TempDir::new().unwrap();
        let target = TempDir::new().unwrap();

        populate_source(source.path());

        let tarball_path = staging.path().join("layer.tar.zst");
        create_snapshot(source.path(), &tarball_path, 3).unwrap();

        let result = extract_snapshot(&tarball_path, target.path(), Some("not-the-real-digest"));

        assert!(matches!(
            result,
            Err(LayerStorageError::ChecksumMismatch { .. })
        ));

        // Nothing may be unpacked and no staging file may remain after a mismatch.
        assert!(std::fs::read_dir(target.path()).unwrap().next().is_none());
        assert!(!tarball_path.with_extension("tar.verify").exists());
    }

    #[test]
    fn test_directory_digest() {
        let dir = TempDir::new().unwrap();

        std::fs::write(dir.path().join("file1.txt"), "content1").unwrap();
        std::fs::write(dir.path().join("file2.txt"), "content2").unwrap();

        let digest1 = calculate_directory_digest(dir.path()).unwrap();

        // Modify a file
        std::fs::write(dir.path().join("file1.txt"), "modified").unwrap();

        let digest2 = calculate_directory_digest(dir.path()).unwrap();

        // Digests should differ
        assert_ne!(digest1, digest2);
    }

    #[test]
    fn test_directory_digest_is_stable() {
        let dir = TempDir::new().unwrap();

        std::fs::write(dir.path().join("file1.txt"), "content1").unwrap();
        std::fs::create_dir(dir.path().join("subdir")).unwrap();
        std::fs::write(dir.path().join("subdir/file2.txt"), "content2").unwrap();

        // Hashing an unchanged tree twice must yield the same digest.
        let first = calculate_directory_digest(dir.path()).unwrap();
        let second = calculate_directory_digest(dir.path()).unwrap();

        assert_eq!(first, second);
    }
}