// streaming_crypto/core_api/compression/codecs/zstd.rs
use std::io::{Cursor, BufReader};
14
15use crate::compression::{compute_checksum, types::{CompressionError, Compressor, Decompressor}, verify_checksum};
16
17pub struct ZstdCompressor {
21 _encoder: Option<zstd::stream::Encoder<'static, Vec<u8>>>, }
23
24pub struct ZstdDecompressor {
29 _buffer: Vec<u8>,
30 _decoder: Option<zstd::stream::Decoder<'static, BufReader<Cursor<Vec<u8>>>>>,
31}
32
33impl ZstdCompressor {
34 pub fn new(level: i32, dict: Option<&[u8]>) -> Result<Box<dyn Compressor + Send>, CompressionError> {
39 let encoder = if let Some(d) = dict {
40 zstd::stream::Encoder::with_dictionary(Vec::new(), level, d)
41 .map_err(|e| CompressionError::CodecInitFailed {
42 codec: "zstd".into(),
43 msg: e.to_string(),
44 })?
45 } else {
46 zstd::stream::Encoder::new(Vec::new(), level)
47 .map_err(|e| CompressionError::CodecInitFailed {
48 codec: "zstd".into(),
49 msg: e.to_string(),
50 })?
51 };
52 Ok(Box::new(Self { _encoder: Some(encoder) }))
53 }
54}
55
56impl Compressor for ZstdCompressor {
57 fn compress_chunk(&mut self, input: &[u8], out: &mut Vec<u8>) -> Result<(), CompressionError> {
58 let compressed = zstd::bulk::compress(input, 6)
60 .map_err(|e| CompressionError::CodecProcessFailed { codec: "zstd".into(), msg: e.to_string() })?;
61
62 let orig_len = input.len() as u32;
64 out.extend_from_slice(&orig_len.to_le_bytes());
65 out.extend_from_slice(&compressed);
66
67 let checksum = compute_checksum(input, None);
69 out.extend_from_slice(&checksum.to_le_bytes());
70
71 Ok(())
72 }
73
74 fn finish(&mut self, _out: &mut Vec<u8>) -> Result<(), CompressionError> {
75 Ok(())
76 }
77}
78
79impl ZstdDecompressor {
80 pub fn new(dict: Option<&[u8]>) -> Result<Box<dyn Decompressor + Send>, CompressionError> {
81 let cursor = Cursor::new(Vec::new());
82 let result: Result<zstd::stream::Decoder<'_, BufReader<Cursor<Vec<u8>>>>, std::io::Error> =
83 if let Some(d) = dict {
84 zstd::stream::Decoder::with_dictionary(BufReader::new(cursor), d)
85 } else {
86 zstd::stream::Decoder::new(cursor)
87 };
88
89 let decoder = match result {
90 Ok(dec) => Some(dec),
91 Err(e) => {
92 return Err(CompressionError::CodecInitFailed {
93 codec: "zstd".into(),
94 msg: e.to_string(),
95 });
96 }
97 };
98
99 Ok(Box::new(Self {
100 _buffer: Vec::new(),
101 _decoder: decoder,
102 }))
103 }
104}
105
106impl Decompressor for ZstdDecompressor {
107 fn decompress_chunk(&mut self, input: &[u8], out: &mut Vec<u8>) -> Result<(), CompressionError> {
108 if input.len() < 8 {
109 return Err(CompressionError::CodecProcessFailed {
110 codec: "zstd".into(),
111 msg: "input too short for length+checksum".into(),
112 });
113 }
114
115 let orig_len = u32::from_le_bytes(input[0..4].try_into().unwrap()) as usize;
117
118 let compressed = &input[4..input.len() - 4];
120 let checksum_bytes = &input[input.len() - 4..];
121 let expected_crc = u32::from_le_bytes(checksum_bytes.try_into().unwrap());
122
123 let decompressed = zstd::bulk::decompress(compressed, orig_len)
125 .map_err(|e| CompressionError::CodecProcessFailed { codec: "zstd".into(), msg: e.to_string() })?;
126
127 if decompressed.len() != orig_len {
129 return Err(CompressionError::CodecProcessFailed {
130 codec: "zstd".into(),
131 msg: format!("decoded size {} != prefix {}", decompressed.len(), orig_len),
132 });
133 }
134
135 let actual_crc = compute_checksum(&decompressed, None);
137 verify_checksum(expected_crc, actual_crc, "zstd".into())?;
138
139 out.extend_from_slice(&decompressed);
140 Ok(())
141 }
142}