structured_zstd/encoding/
frame_compressor.rs1use alloc::{boxed::Box, vec::Vec};
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12 CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13 match_generator::MatchGeneratorDriver,
14};
15use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table};
16
17use crate::io::{Read, Write};
18
19pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
39 uncompressed_data: Option<R>,
40 compressed_data: Option<W>,
41 compression_level: CompressionLevel,
42 state: CompressState<M>,
43 #[cfg(feature = "hash")]
44 hasher: XxHash64,
45}
46
47#[derive(Clone)]
48pub(crate) enum PreviousFseTable {
49 Default,
52 Custom(Box<FSETable>),
53}
54
55impl PreviousFseTable {
56 pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable {
57 match self {
58 Self::Default => default,
59 Self::Custom(table) => table,
60 }
61 }
62}
63
64pub(crate) struct FseTables {
65 pub(crate) ll_default: FSETable,
66 pub(crate) ll_previous: Option<PreviousFseTable>,
67 pub(crate) ml_default: FSETable,
68 pub(crate) ml_previous: Option<PreviousFseTable>,
69 pub(crate) of_default: FSETable,
70 pub(crate) of_previous: Option<PreviousFseTable>,
71}
72
73impl FseTables {
74 pub fn new() -> Self {
75 Self {
76 ll_default: default_ll_table(),
77 ll_previous: None,
78 ml_default: default_ml_table(),
79 ml_previous: None,
80 of_default: default_of_table(),
81 of_previous: None,
82 }
83 }
84}
85
86pub(crate) struct CompressState<M: Matcher> {
87 pub(crate) matcher: M,
88 pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
89 pub(crate) fse_tables: FseTables,
90 pub(crate) offset_hist: [u32; 3],
93}
94
95impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
96 pub fn new(compression_level: CompressionLevel) -> Self {
98 Self {
99 uncompressed_data: None,
100 compressed_data: None,
101 compression_level,
102 state: CompressState {
103 matcher: MatchGeneratorDriver::new(1024 * 128, 1),
104 last_huff_table: None,
105 fse_tables: FseTables::new(),
106 offset_hist: [1, 4, 8],
107 },
108 #[cfg(feature = "hash")]
109 hasher: XxHash64::with_seed(0),
110 }
111 }
112}
113
114impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
115 pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
117 Self {
118 uncompressed_data: None,
119 compressed_data: None,
120 state: CompressState {
121 matcher,
122 last_huff_table: None,
123 fse_tables: FseTables::new(),
124 offset_hist: [1, 4, 8],
125 },
126 compression_level,
127 #[cfg(feature = "hash")]
128 hasher: XxHash64::with_seed(0),
129 }
130 }
131
132 pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
136 self.uncompressed_data.replace(uncompressed_data)
137 }
138
139 pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
143 self.compressed_data.replace(compressed_data)
144 }
145
146 pub fn compress(&mut self) {
154 self.state.matcher.reset(self.compression_level);
156 self.state.last_huff_table = None;
157 self.state.fse_tables.ll_previous = None;
158 self.state.fse_tables.ml_previous = None;
159 self.state.fse_tables.of_previous = None;
160 self.state.offset_hist = [1, 4, 8];
161 #[cfg(feature = "hash")]
162 {
163 self.hasher = XxHash64::with_seed(0);
164 }
165 let source = self.uncompressed_data.as_mut().unwrap();
166 let drain = self.compressed_data.as_mut().unwrap();
167 let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
169 let header = FrameHeader {
171 frame_content_size: None,
172 single_segment: false,
173 content_checksum: cfg!(feature = "hash"),
174 dictionary_id: None,
175 window_size: Some(self.state.matcher.window_size()),
176 };
177 header.serialize(output);
178 loop {
180 let mut uncompressed_data = self.state.matcher.get_next_space();
182 let mut read_bytes = 0;
183 let last_block;
184 'read_loop: loop {
185 let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
186 if new_bytes == 0 {
187 last_block = true;
188 break 'read_loop;
189 }
190 read_bytes += new_bytes;
191 if read_bytes == uncompressed_data.len() {
192 last_block = false;
193 break 'read_loop;
194 }
195 }
196 uncompressed_data.resize(read_bytes, 0);
197 #[cfg(feature = "hash")]
199 self.hasher.write(&uncompressed_data);
200 if uncompressed_data.is_empty() {
202 let header = BlockHeader {
203 last_block: true,
204 block_type: crate::blocks::block::BlockType::Raw,
205 block_size: 0,
206 };
207 header.serialize(output);
209 drain.write_all(output).unwrap();
210 output.clear();
211 break;
212 }
213
214 match self.compression_level {
215 CompressionLevel::Uncompressed => {
216 let header = BlockHeader {
217 last_block,
218 block_type: crate::blocks::block::BlockType::Raw,
219 block_size: read_bytes.try_into().unwrap(),
220 };
221 header.serialize(output);
223 output.extend_from_slice(&uncompressed_data);
224 }
225 CompressionLevel::Fastest | CompressionLevel::Default => {
226 compress_fastest(&mut self.state, last_block, uncompressed_data, output)
229 }
230 _ => {
231 unimplemented!();
232 }
233 }
234 drain.write_all(output).unwrap();
235 output.clear();
236 if last_block {
237 break;
238 }
239 }
240
241 #[cfg(feature = "hash")]
244 {
245 let content_checksum = self.hasher.finish();
248 drain
249 .write_all(&(content_checksum as u32).to_le_bytes())
250 .unwrap();
251 }
252 }
253
254 pub fn source_mut(&mut self) -> Option<&mut R> {
256 self.uncompressed_data.as_mut()
257 }
258
259 pub fn drain_mut(&mut self) -> Option<&mut W> {
261 self.compressed_data.as_mut()
262 }
263
264 pub fn source(&self) -> Option<&R> {
266 self.uncompressed_data.as_ref()
267 }
268
269 pub fn drain(&self) -> Option<&W> {
271 self.compressed_data.as_ref()
272 }
273
274 pub fn take_source(&mut self) -> Option<R> {
276 self.uncompressed_data.take()
277 }
278
279 pub fn take_drain(&mut self) -> Option<W> {
281 self.compressed_data.take()
282 }
283
284 pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
286 core::mem::swap(&mut match_generator, &mut self.state.matcher);
287 match_generator
288 }
289
290 pub fn set_compression_level(
292 &mut self,
293 compression_level: CompressionLevel,
294 ) -> CompressionLevel {
295 let old = self.compression_level;
296 self.compression_level = compression_level;
297 old
298 }
299
300 pub fn compression_level(&self) -> CompressionLevel {
302 self.compression_level
303 }
304}
305
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::FrameCompressor;
    use crate::common::MAGIC_NUM;
    use crate::decoding::FrameDecoder;
    use alloc::vec::Vec;

    /// Compresses `input` into a fresh buffer using `CompressionLevel::Uncompressed`.
    fn compress_raw(input: &[u8]) -> Vec<u8> {
        let mut compressed: Vec<u8> = Vec::new();
        {
            // Scoped so the mutable borrow of `compressed` ends before it is returned.
            let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
            compressor.set_source(input);
            compressor.set_drain(&mut compressed);
            compressor.compress();
        }
        compressed
    }

    /// Decodes a frame with this crate's `FrameDecoder`.
    ///
    /// The output vector is pre-sized to `expected_len`, as the tests have
    /// always done before calling `decode_all_to_vec`.
    fn decode_with_frame_decoder(compressed: &[u8], expected_len: usize) -> Vec<u8> {
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(expected_len);
        decoder.decode_all_to_vec(compressed, &mut decoded).unwrap();
        decoded
    }

    /// Decodes a frame with the reference `zstd` crate.
    fn decode_with_reference_zstd(compressed: &[u8]) -> Vec<u8> {
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(compressed, &mut decoded).unwrap();
        decoded
    }

    #[test]
    fn frame_starts_with_magic_num() {
        let output = compress_raw(&[1_u8, 2, 3]);
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data = [1_u8, 2, 3];
        let output = compress_raw(&mock_data);

        // Besides not panicking, the frame must decode back to the input.
        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(decoded, mock_data);
    }

    #[test]
    fn very_simple_compress() {
        let mut mock_data: Vec<u8> = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let output = compress_raw(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_reference_zstd(&output);
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data: Vec<u8> = vec![0; 1 << 19];
        let output = compress_raw(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn aaa_compress() {
        let mock_data: Vec<u8> = vec![0, 1, 3, 4, 5];
        let output = compress_raw(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_reference_zstd(&output);
        assert_eq!(mock_data, decoded);
    }

    /// Compresses the same data twice with one compressor and checks that both
    /// frames decode correctly and carry the same checksum.
    #[cfg(feature = "hash")]
    #[test]
    fn checksum_two_frames_reused_compressor() {
        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);

        let mut compressed1 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed1);
        compressor.compress();

        let mut compressed2 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed2);
        compressor.compress();

        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
            let mut decoder = FrameDecoder::new();
            let mut source = compressed;
            decoder.reset(&mut source).unwrap();
            while !decoder.is_finished() {
                decoder
                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
                    .unwrap();
            }
            let mut decoded = Vec::new();
            decoder.collect_to_writer(&mut decoded).unwrap();
            (
                decoded,
                decoder.get_checksum_from_data(),
                decoder.get_calculated_checksum(),
            )
        }

        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
        assert_eq!(
            chksum_from_data1, chksum_calculated1,
            "frame 1: checksum mismatch"
        );

        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
        assert_eq!(
            chksum_from_data2, chksum_calculated2,
            "frame 2: checksum mismatch"
        );

        assert_eq!(
            chksum_from_data1, chksum_from_data2,
            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
        );
    }

    /// Replays files from `fuzz/artifacts/interop`, if that directory exists,
    /// through both this crate and the reference `zstd` crate in each direction.
    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;
        fn decode_szstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        fn decode_szstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        fn encode_szstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        fn encode_szstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                if file.as_ref().unwrap().file_type().unwrap().is_file() {
                    let data = std::fs::read(file.unwrap().path()).unwrap();
                    let data = data.as_slice();
                    // Reference encoder -> both decoders of this crate.
                    let compressed = encode_zstd(data).unwrap();
                    let decoded = decode_szstd(&mut compressed.as_slice());
                    let decoded2 = decode_szstd_writer(&mut compressed.as_slice());
                    assert!(
                        decoded == data,
                        "Decoded data did not match the original input during decompression"
                    );
                    assert_eq!(
                        decoded2, data,
                        "Decoded data did not match the original input during decompression"
                    );

                    // This crate's uncompressed encoder -> reference decoder.
                    let mut input = data;
                    let compressed = encode_szstd_uncompressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                    // This crate's fastest encoder -> reference decoder.
                    let mut input = data;
                    let compressed = encode_szstd_compressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                }
            }
        }
    }
}