Skip to main content

streaming_crypto/core_api/compression/codecs/
zstd.rs

//! src/compression/codecs/zstd.rs
//!
//! Zstd chunk compressor/decompressor.
//!
//! Design notes:
//! - Each chunk is compressed on its own with the zstd bulk API and framed as
//!   `[u32 LE plaintext length][zstd payload][checksum, LE]`.
//! - Errors are mapped into `CompressionError` variants with codec context.
//! - The constructors still build zstd streaming encoder/decoder objects, but the chunk
//!   methods in this file do not use them; `finish` is a no-op.
//! - The decompressor checks the decoded size against the length prefix and verifies the
//!   checksum before appending anything to the caller's buffer.
10
// Option 1 (used below): the zstd bulk API.
// `zstd::bulk::compress` / `zstd::bulk::decompress` produce and consume standalone compressed
// frames, so each chunk can be decompressed independently of the others.
13use std::io::{Cursor, BufReader};
14
15use crate::compression::{compute_checksum, types::{CompressionError, Compressor, Decompressor}, verify_checksum};
16
17/// Zstd streaming compressor.
18/// - Holds an encoder writing into an internal Vec.
19/// - Implements `Compressor` trait for chunked compression.
20pub struct ZstdCompressor {
21    _encoder: Option<zstd::stream::Encoder<'static, Vec<u8>>>, // wrapped in Option to allow finish()
22}
23
24/// Zstd streaming decompressor.
25/// - Buffers compressed input.
26/// - Reconstructs decoder per chunk for simplicity.
27/// - Implements `Decompressor` trait.
28pub struct ZstdDecompressor {
29    _buffer: Vec<u8>,
30    _decoder: Option<zstd::stream::Decoder<'static, BufReader<Cursor<Vec<u8>>>>>,
31}
32
33impl ZstdCompressor {
34    /// Create a new Zstd compressor with given level and optional dictionary.
35    ///
36    /// # Errors
37    /// - Returns `CompressionError::CodecInitFailed` if encoder initialization fails.
38    pub fn new(level: i32, dict: Option<&[u8]>) -> Result<Box<dyn Compressor + Send>, CompressionError> {
39        let encoder = if let Some(d) = dict {
40            zstd::stream::Encoder::with_dictionary(Vec::new(), level, d)
41                .map_err(|e| CompressionError::CodecInitFailed {
42                    codec: "zstd".into(),
43                    msg: e.to_string(),
44                })?
45        } else {
46            zstd::stream::Encoder::new(Vec::new(), level)
47                .map_err(|e| CompressionError::CodecInitFailed {
48                    codec: "zstd".into(),
49                    msg: e.to_string(),
50                })?
51        };
52        Ok(Box::new(Self { _encoder: Some(encoder) }))
53    }
54}
55
56impl Compressor for ZstdCompressor {
57    fn compress_chunk(&mut self, input: &[u8], out: &mut Vec<u8>) -> Result<(), CompressionError> {
58        // Compress the input
59        let compressed = zstd::bulk::compress(input, 6)
60            .map_err(|e| CompressionError::CodecProcessFailed { codec: "zstd".into(), msg: e.to_string() })?;
61
62        // Prefix with original plaintext length (like lz4_flex does)
63        let orig_len = input.len() as u32;
64        out.extend_from_slice(&orig_len.to_le_bytes());
65        out.extend_from_slice(&compressed);
66
67        // Append CRC32 of original plaintext
68        let checksum = compute_checksum(input, None);
69        out.extend_from_slice(&checksum.to_le_bytes());
70        
71        Ok(())
72    }
73
74    fn finish(&mut self, _out: &mut Vec<u8>) -> Result<(), CompressionError> {
75        Ok(())
76    }
77}
78
79impl ZstdDecompressor {
80    pub fn new(dict: Option<&[u8]>) -> Result<Box<dyn Decompressor + Send>, CompressionError> {
81        let cursor = Cursor::new(Vec::new());
82        let result: Result<zstd::stream::Decoder<'_, BufReader<Cursor<Vec<u8>>>>, std::io::Error> =
83            if let Some(d) = dict {
84                zstd::stream::Decoder::with_dictionary(BufReader::new(cursor), d)
85            } else {
86                zstd::stream::Decoder::new(cursor)
87            };
88
89        let decoder = match result {
90            Ok(dec) => Some(dec),
91            Err(e) => {
92                return Err(CompressionError::CodecInitFailed {
93                    codec: "zstd".into(),
94                    msg: e.to_string(),
95                });
96            }
97        };
98
99        Ok(Box::new(Self {
100            _buffer: Vec::new(),
101            _decoder: decoder,
102        }))
103    }
104}
105
106impl Decompressor for ZstdDecompressor {
107    fn decompress_chunk(&mut self, input: &[u8], out: &mut Vec<u8>) -> Result<(), CompressionError> {
108        if input.len() < 8 {
109            return Err(CompressionError::CodecProcessFailed {
110                codec: "zstd".into(),
111                msg: "input too short for length+checksum".into(),
112            });
113        }
114
115        // Read original length prefix
116        let orig_len = u32::from_le_bytes(input[0..4].try_into().unwrap()) as usize;
117
118        // compressed data is everything except the last 4 bytes
119        let compressed = &input[4..input.len() - 4];
120        let checksum_bytes = &input[input.len() - 4..];
121        let expected_crc = u32::from_le_bytes(checksum_bytes.try_into().unwrap());
122
123        // Decompress with known output size
124        let decompressed = zstd::bulk::decompress(compressed, orig_len)
125            .map_err(|e| CompressionError::CodecProcessFailed { codec: "zstd".into(), msg: e.to_string() })?;
126
127        // Optional sanity check: verify decoded size matches prefix
128        if decompressed.len() != orig_len {
129            return Err(CompressionError::CodecProcessFailed {
130                codec: "zstd".into(),
131                msg: format!("decoded size {} != prefix {}", decompressed.len(), orig_len),
132            });
133        }
134
135        // Verify checksum
136        let actual_crc = compute_checksum(&decompressed, None);
137        verify_checksum(expected_crc, actual_crc, "zstd".into())?;
138
139        out.extend_from_slice(&decompressed);
140        Ok(())
141    }
142}