structured_zstd/encoding/
frame_compressor.rs1use alloc::vec::Vec;
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12 block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13 match_generator::MatchGeneratorDriver, CompressionLevel, Matcher,
14};
15use crate::fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable};
16
17use crate::io::{Read, Write};
18
19pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
39 uncompressed_data: Option<R>,
40 compressed_data: Option<W>,
41 compression_level: CompressionLevel,
42 state: CompressState<M>,
43 #[cfg(feature = "hash")]
44 hasher: XxHash64,
45}
46
47pub(crate) struct FseTables {
48 pub(crate) ll_default: FSETable,
49 pub(crate) ll_previous: Option<FSETable>,
50 pub(crate) ml_default: FSETable,
51 pub(crate) ml_previous: Option<FSETable>,
52 pub(crate) of_default: FSETable,
53 pub(crate) of_previous: Option<FSETable>,
54}
55
56impl FseTables {
57 pub fn new() -> Self {
58 Self {
59 ll_default: default_ll_table(),
60 ll_previous: None,
61 ml_default: default_ml_table(),
62 ml_previous: None,
63 of_default: default_of_table(),
64 of_previous: None,
65 }
66 }
67}
68
69pub(crate) struct CompressState<M: Matcher> {
70 pub(crate) matcher: M,
71 pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
72 pub(crate) fse_tables: FseTables,
73}
74
75impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
76 pub fn new(compression_level: CompressionLevel) -> Self {
78 Self {
79 uncompressed_data: None,
80 compressed_data: None,
81 compression_level,
82 state: CompressState {
83 matcher: MatchGeneratorDriver::new(1024 * 128, 1),
84 last_huff_table: None,
85 fse_tables: FseTables::new(),
86 },
87 #[cfg(feature = "hash")]
88 hasher: XxHash64::with_seed(0),
89 }
90 }
91}
92
93impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
94 pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
96 Self {
97 uncompressed_data: None,
98 compressed_data: None,
99 state: CompressState {
100 matcher,
101 last_huff_table: None,
102 fse_tables: FseTables::new(),
103 },
104 compression_level,
105 #[cfg(feature = "hash")]
106 hasher: XxHash64::with_seed(0),
107 }
108 }
109
110 pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
114 self.uncompressed_data.replace(uncompressed_data)
115 }
116
117 pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
121 self.compressed_data.replace(compressed_data)
122 }
123
124 pub fn compress(&mut self) {
132 self.state.matcher.reset(self.compression_level);
134 self.state.last_huff_table = None;
135 #[cfg(feature = "hash")]
136 {
137 self.hasher = XxHash64::with_seed(0);
138 }
139 let source = self.uncompressed_data.as_mut().unwrap();
140 let drain = self.compressed_data.as_mut().unwrap();
141 let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
143 let header = FrameHeader {
145 frame_content_size: None,
146 single_segment: false,
147 content_checksum: cfg!(feature = "hash"),
148 dictionary_id: None,
149 window_size: Some(self.state.matcher.window_size()),
150 };
151 header.serialize(output);
152 loop {
154 let mut uncompressed_data = self.state.matcher.get_next_space();
156 let mut read_bytes = 0;
157 let last_block;
158 'read_loop: loop {
159 let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
160 if new_bytes == 0 {
161 last_block = true;
162 break 'read_loop;
163 }
164 read_bytes += new_bytes;
165 if read_bytes == uncompressed_data.len() {
166 last_block = false;
167 break 'read_loop;
168 }
169 }
170 uncompressed_data.resize(read_bytes, 0);
171 #[cfg(feature = "hash")]
173 self.hasher.write(&uncompressed_data);
174 if uncompressed_data.is_empty() {
176 let header = BlockHeader {
177 last_block: true,
178 block_type: crate::blocks::block::BlockType::Raw,
179 block_size: 0,
180 };
181 header.serialize(output);
183 drain.write_all(output).unwrap();
184 output.clear();
185 break;
186 }
187
188 match self.compression_level {
189 CompressionLevel::Uncompressed => {
190 let header = BlockHeader {
191 last_block,
192 block_type: crate::blocks::block::BlockType::Raw,
193 block_size: read_bytes.try_into().unwrap(),
194 };
195 header.serialize(output);
197 output.extend_from_slice(&uncompressed_data);
198 }
199 CompressionLevel::Fastest => {
200 compress_fastest(&mut self.state, last_block, uncompressed_data, output)
201 }
202 _ => {
203 unimplemented!();
204 }
205 }
206 drain.write_all(output).unwrap();
207 output.clear();
208 if last_block {
209 break;
210 }
211 }
212
213 #[cfg(feature = "hash")]
216 {
217 let content_checksum = self.hasher.finish();
220 drain
221 .write_all(&(content_checksum as u32).to_le_bytes())
222 .unwrap();
223 }
224 }
225
226 pub fn source_mut(&mut self) -> Option<&mut R> {
228 self.uncompressed_data.as_mut()
229 }
230
231 pub fn drain_mut(&mut self) -> Option<&mut W> {
233 self.compressed_data.as_mut()
234 }
235
236 pub fn source(&self) -> Option<&R> {
238 self.uncompressed_data.as_ref()
239 }
240
241 pub fn drain(&self) -> Option<&W> {
243 self.compressed_data.as_ref()
244 }
245
246 pub fn take_source(&mut self) -> Option<R> {
248 self.uncompressed_data.take()
249 }
250
251 pub fn take_drain(&mut self) -> Option<W> {
253 self.compressed_data.take()
254 }
255
256 pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
258 core::mem::swap(&mut match_generator, &mut self.state.matcher);
259 match_generator
260 }
261
262 pub fn set_compression_level(
264 &mut self,
265 compression_level: CompressionLevel,
266 ) -> CompressionLevel {
267 let old = self.compression_level;
268 self.compression_level = compression_level;
269 old
270 }
271
272 pub fn compression_level(&self) -> CompressionLevel {
274 self.compression_level
275 }
276}
277
#[cfg(test)]
mod tests {
    // Round-trip and interoperability tests for `FrameCompressor`.
    use alloc::vec;

    use super::FrameCompressor;
    use crate::common::MAGIC_NUM;
    use crate::decoding::FrameDecoder;
    use alloc::vec::Vec;

    #[test]
    fn frame_starts_with_magic_num() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data);
        compressor.set_drain(&mut output);

        compressor.compress();
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data);
        compressor.set_drain(&mut output);

        compressor.compress();

        // The raw frame must decode back to the original bytes.
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded.as_slice());
    }

    #[test]
    fn very_simple_compress() {
        // Run lengths straddle the 128 KiB block size to exercise block boundaries.
        let mut mock_data = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);

        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data = vec![0; 1 << 19];
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn aaa_compress() {
        let mock_data = vec![0, 1, 3, 4, 5];
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);

        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[cfg(feature = "hash")]
    #[test]
    fn checksum_two_frames_reused_compressor() {
        // Compress the same data twice with one compressor; both frames must
        // carry identical, correct checksums (the hasher is reset per frame).
        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);

        let mut compressed1 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed1);
        compressor.compress();

        let mut compressed2 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed2);
        compressor.compress();

        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
            let mut decoder = FrameDecoder::new();
            let mut source = compressed;
            decoder.reset(&mut source).unwrap();
            while !decoder.is_finished() {
                decoder
                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
                    .unwrap();
            }
            let mut decoded = Vec::new();
            decoder.collect_to_writer(&mut decoded).unwrap();
            (
                decoded,
                decoder.get_checksum_from_data(),
                decoder.get_calculated_checksum(),
            )
        }

        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
        assert_eq!(
            chksum_from_data1, chksum_calculated1,
            "frame 1: checksum mismatch"
        );

        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
        assert_eq!(
            chksum_from_data2, chksum_calculated2,
            "frame 2: checksum mismatch"
        );

        assert_eq!(
            chksum_from_data1, chksum_from_data2,
            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;

        fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        fn decode_ruzstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        fn encode_ruzstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        fn encode_ruzstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }

        // Replay every saved interop fuzz artifact in both directions.
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                let entry = file.unwrap();
                if !entry.file_type().unwrap().is_file() {
                    continue;
                }
                let data = std::fs::read(entry.path()).unwrap();
                let data = data.as_slice();

                // Reference zstd encodes; both ruzstd decoding APIs must recover the input.
                let compressed = encode_zstd(data).unwrap();
                let decoded = decode_ruzstd(&mut compressed.as_slice());
                let decoded2 = decode_ruzstd_writer(&mut compressed.as_slice());
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during decompression"
                );
                assert_eq!(
                    decoded2, data,
                    "Decoded data did not match the original input during decompression"
                );

                // ruzstd encodes (raw blocks); reference zstd must recover the input.
                let mut input = data;
                let compressed = encode_ruzstd_uncompressed(&mut input);
                let decoded = decode_zstd(&compressed).unwrap();
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during compression"
                );

                // ruzstd encodes (Fastest level); reference zstd must recover the input.
                let mut input = data;
                let compressed = encode_ruzstd_compressed(&mut input);
                let decoded = decode_zstd(&compressed).unwrap();
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during compression"
                );
            }
        }
    }
}