Skip to main content

zrip_encode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Write};
4
5use crate::block_encoder::{self, BlockEncodeWorkspace};
6use crate::dfast;
7use crate::fast;
8use crate::strategy::{self, LevelParams, Strategy};
9use zrip_core::Sequence;
10use zrip_core::dict::Dictionary;
11use zrip_core::error::CompressError;
12use zrip_core::frame::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use zrip_core::xxhash::Xxh64State;
14
15/// Streaming zstd compressor implementing [`Write`].
16///
17/// Buffers input until a full block (128 KiB) is ready, then compresses
18/// and writes it to the underlying writer. Call [`finish`](Self::finish)
19/// to flush the final block, write the content checksum, and recover the
20/// writer.
21///
22/// Internal buffers (hash tables, sequence scratch, block encoder workspace)
23/// are allocated once and reused across blocks. To reuse them across
24/// multiple frames, call [`reset`](Self::reset) instead of `finish`:
25///
26/// ```
27/// use std::io::Write;
28///
29/// let mut encoder = zrip::FrameEncoder::new(Vec::new(), 1).unwrap();
30/// encoder.write_all(b"first frame").unwrap();
31/// let first = encoder.reset(Vec::new()).unwrap();   // reuses buffers
32/// encoder.write_all(b"second frame").unwrap();
33/// let second = encoder.finish().unwrap();
34/// ```
35pub struct FrameEncoder<W: Write> {
36    inner: W,
37    params: LevelParams,
38    buffer: Vec<u8>,
39    rep_offsets: [u32; 3],
40    hasher: Xxh64State,
41    header_written: bool,
42    finished: bool,
43    workspace: BlockEncodeWorkspace,
44    dict: Option<Dictionary>,
45    first_block: bool,
46    hash_table: Vec<u32>,
47    hash_long: Vec<u32>,
48    sequences: Vec<Sequence>,
49    combined: Vec<u8>,
50    block_out: Vec<u8>,
51}
52
53impl<W: Write> FrameEncoder<W> {
54    /// Creates a new streaming encoder at the given level (-7..=4).
55    pub fn new(writer: W, level: i32) -> Result<Self, CompressError> {
56        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
57        let (hash_table, hash_long) = alloc_hash_tables(&params);
58        Ok(Self {
59            inner: writer,
60            params,
61            buffer: Vec::new(),
62            rep_offsets: [1, 4, 8],
63            hasher: Xxh64State::new(0),
64            header_written: false,
65            finished: false,
66            workspace: BlockEncodeWorkspace::new(),
67            dict: None,
68            first_block: false,
69            hash_table,
70            hash_long,
71            sequences: Vec::new(),
72            combined: Vec::new(),
73            block_out: Vec::new(),
74        })
75    }
76
77    /// Creates a new streaming encoder with a dictionary at the given level (-7..=4).
78    pub fn with_dict(writer: W, level: i32, dict: Dictionary) -> Result<Self, CompressError> {
79        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
80        let (hash_table, hash_long) = alloc_hash_tables(&params);
81        let rep_offsets = *dict.rep_offsets();
82        Ok(Self {
83            inner: writer,
84            params,
85            buffer: Vec::new(),
86            rep_offsets,
87            hasher: Xxh64State::new(0),
88            header_written: false,
89            finished: false,
90            workspace: BlockEncodeWorkspace::new(),
91            dict: Some(dict),
92            first_block: true,
93            hash_table,
94            hash_long,
95            sequences: Vec::new(),
96            combined: Vec::new(),
97            block_out: Vec::new(),
98        })
99    }
100
101    /// Flushes remaining data, writes the content checksum, and returns the inner writer.
102    pub fn finish(mut self) -> Result<W, io::Error> {
103        self.finish_frame()?;
104        Ok(self.inner)
105    }
106
107    /// Finishes the current frame and installs `new_writer` for the next one.
108    ///
109    /// Returns the previous writer containing the completed frame. All
110    /// internal buffers (hash tables, workspace, block scratch) stay
111    /// allocated and are reused for the next frame.
112    pub fn reset(&mut self, new_writer: W) -> Result<W, io::Error> {
113        self.finish_frame()?;
114        let old = core::mem::replace(&mut self.inner, new_writer);
115        self.header_written = false;
116        self.finished = false;
117        self.first_block = self.dict.is_some();
118        self.rep_offsets = match &self.dict {
119            Some(d) => *d.rep_offsets(),
120            None => [1, 4, 8],
121        };
122        self.hasher = Xxh64State::new(0);
123        self.workspace.prev_huffman = None;
124        Ok(old)
125    }
126
127    fn finish_frame(&mut self) -> io::Result<()> {
128        if self.finished {
129            return Ok(());
130        }
131        self.finished = true;
132
133        if !self.header_written {
134            self.write_header()?;
135        }
136
137        self.flush_block(true)?;
138
139        let hash = self.hasher.finish();
140        let checksum = (hash & 0xFFFF_FFFF) as u32;
141        self.inner.write_all(&checksum.to_le_bytes())?;
142        Ok(())
143    }
144
145    fn write_header(&mut self) -> io::Result<()> {
146        self.header_written = true;
147
148        self.inner.write_all(&ZSTD_MAGIC.to_le_bytes())?;
149
150        let window_log = self.params.window_log;
151
152        let dict_id_flag = if let Some(ref dict) = self.dict {
153            let id = dict.id();
154            if id <= 0xFF {
155                1u8
156            } else if id <= 0xFFFF {
157                2
158            } else {
159                3
160            }
161        } else {
162            0
163        };
164
165        let descriptor = 0x04u8 | dict_id_flag;
166        self.inner.write_all(&[descriptor])?;
167
168        let mantissa = 0u8;
169        let exponent = (window_log - 10) as u8;
170        let window_descriptor = (exponent << 3) | mantissa;
171        self.inner.write_all(&[window_descriptor])?;
172
173        if let Some(ref dict) = self.dict {
174            let id = dict.id();
175            match dict_id_flag {
176                1 => self.inner.write_all(&[id as u8])?,
177                2 => self.inner.write_all(&(id as u16).to_le_bytes())?,
178                3 => self.inner.write_all(&id.to_le_bytes())?,
179                _ => unreachable!(),
180            }
181        }
182
183        Ok(())
184    }
185
186    fn flush_block(&mut self, last: bool) -> io::Result<()> {
187        if self.buffer.is_empty() && last {
188            self.block_out.clear();
189            block_encoder::encode_raw_block(&[], true, &mut self.block_out);
190            self.inner.write_all(&self.block_out)?;
191            return Ok(());
192        }
193
194        if self.buffer.is_empty() {
195            return Ok(());
196        }
197
198        let chunk = core::mem::take(&mut self.buffer);
199
200        self.block_out.clear();
201        self.block_out.reserve(chunk.len() + 32);
202        if crate::block_looks_incompressible(&chunk) {
203            block_encoder::encode_raw_block(&chunk, last, &mut self.block_out);
204        } else {
205            let use_prefix = self.first_block && self.dict.is_some();
206            if use_prefix {
207                let prefix = self.dict.as_ref().unwrap().content();
208                match self.params.strategy {
209                    Strategy::Fast => {
210                        fast::compress_fast_with_prefix_reuse(
211                            &chunk,
212                            &self.params,
213                            &self.rep_offsets,
214                            prefix,
215                            &mut self.hash_table,
216                            &mut self.sequences,
217                            &mut self.combined,
218                        );
219                    }
220                    Strategy::DFast => {
221                        dfast::compress_dfast_with_prefix_reuse(
222                            &chunk,
223                            &self.params,
224                            &self.rep_offsets,
225                            prefix,
226                            &mut self.hash_table,
227                            &mut self.hash_long,
228                            &mut self.sequences,
229                            &mut self.combined,
230                        );
231                    }
232                }
233            } else {
234                self.hash_table.fill(0);
235                if !self.hash_long.is_empty() {
236                    self.hash_long.fill(0);
237                }
238                match self.params.strategy {
239                    Strategy::Fast => {
240                        fast::compress_fast_block(
241                            &chunk,
242                            0,
243                            chunk.len(),
244                            &self.params,
245                            &self.rep_offsets,
246                            &mut self.hash_table,
247                            &mut self.sequences,
248                        );
249                    }
250                    Strategy::DFast => {
251                        dfast::compress_dfast_block(
252                            &chunk,
253                            0,
254                            chunk.len(),
255                            &self.params,
256                            &self.rep_offsets,
257                            &mut self.hash_table,
258                            &mut self.hash_long,
259                            &mut self.sequences,
260                        );
261                    }
262                }
263            }
264            if self.params.force_raw_literals {
265                block_encoder::encode_compressed_block_raw(
266                    &chunk,
267                    &self.sequences,
268                    &mut self.rep_offsets,
269                    last,
270                    &mut self.block_out,
271                    &mut self.workspace,
272                );
273            } else {
274                block_encoder::encode_compressed_block(
275                    &chunk,
276                    &self.sequences,
277                    &mut self.rep_offsets,
278                    last,
279                    &mut self.block_out,
280                    &mut self.workspace,
281                );
282            }
283        }
284
285        self.first_block = false;
286        self.inner.write_all(&self.block_out)?;
287        Ok(())
288    }
289}
290
291fn alloc_hash_tables(params: &LevelParams) -> (Vec<u32>, Vec<u32>) {
292    match params.strategy {
293        Strategy::Fast => (vec![0u32; 1usize << params.hash_log], Vec::new()),
294        Strategy::DFast => (
295            vec![0u32; 1usize << params.chain_log],
296            vec![0u32; 1usize << params.hash_log],
297        ),
298    }
299}
300
301impl<W: Write> Write for FrameEncoder<W> {
302    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
303        if self.finished {
304            return Err(io::Error::other("encoder already finished"));
305        }
306
307        if !self.header_written {
308            self.write_header()?;
309        }
310
311        self.hasher.update(buf);
312
313        let mut consumed = 0;
314        while consumed < buf.len() {
315            let space = MAX_BLOCK_SIZE - self.buffer.len();
316            let n = space.min(buf.len() - consumed);
317            self.buffer.extend_from_slice(&buf[consumed..consumed + n]);
318            consumed += n;
319
320            if self.buffer.len() >= MAX_BLOCK_SIZE {
321                self.flush_block(false)?;
322            }
323        }
324
325        Ok(consumed)
326    }
327
328    fn flush(&mut self) -> io::Result<()> {
329        if !self.buffer.is_empty() {
330            self.flush_block(false)?;
331        }
332        self.inner.flush()
333    }
334}