Skip to main content

zrip_encode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Write};
4
5use crate::block_encoder::{self, BlockEncodeWorkspace};
6use crate::dfast;
7use crate::fast;
8use crate::strategy::{self, LevelParams, Strategy};
9use zrip_core::Sequence;
10use zrip_core::dict::Dictionary;
11use zrip_core::error::CompressError;
12use zrip_core::frame::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use zrip_core::xxhash::Xxh64State;
14
15/// Streaming zstd compressor implementing [`Write`].
16///
17/// Buffers input until a full block (128 KiB) is ready, then compresses
18/// and writes it to the underlying writer. Call [`finish`](Self::finish)
19/// to flush the final block, write the content checksum, and recover the
20/// writer.
21///
22/// Internal buffers (hash tables, sequence scratch, block encoder workspace)
23/// are allocated once and reused across blocks. To reuse them across
24/// multiple frames, call [`reset`](Self::reset) instead of `finish`:
25///
26/// ```
27/// use std::io::Write;
28///
29/// let mut encoder = zrip::FrameEncoder::new(Vec::new(), 1).unwrap();
30/// encoder.write_all(b"first frame").unwrap();
31/// let first = encoder.reset(Vec::new()).unwrap();   // reuses buffers
32/// encoder.write_all(b"second frame").unwrap();
33/// let second = encoder.finish().unwrap();
34/// ```
35pub struct FrameEncoder<W: Write> {
36    inner: W,
37    params: LevelParams,
38    buffer: Vec<u8>,
39    rep_offsets: [u32; 3],
40    hasher: Xxh64State,
41    header_written: bool,
42    finished: bool,
43    workspace: BlockEncodeWorkspace,
44    dict: Option<Dictionary>,
45    first_block: bool,
46    hash_table: Vec<u32>,
47    hash_long: Vec<u32>,
48    dict_hash: Vec<u32>,
49    sequences: Vec<Sequence>,
50    combined: Vec<u8>,
51    block_out: Vec<u8>,
52}
53
54impl<W: Write> FrameEncoder<W> {
55    /// Creates a new streaming encoder at the given level (-7..=4).
56    pub fn new(writer: W, level: i32) -> Result<Self, CompressError> {
57        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
58        let (hash_table, hash_long) = alloc_hash_tables(&params);
59        Ok(Self {
60            inner: writer,
61            params,
62            buffer: Vec::new(),
63            rep_offsets: [1, 4, 8],
64            hasher: Xxh64State::new(0),
65            header_written: false,
66            finished: false,
67            workspace: BlockEncodeWorkspace::new(),
68            dict: None,
69            first_block: false,
70            hash_table,
71            hash_long,
72            dict_hash: Vec::new(),
73            sequences: Vec::new(),
74            combined: Vec::new(),
75            block_out: Vec::new(),
76        })
77    }
78
79    /// Creates a new streaming encoder with a dictionary at the given level (-7..=4).
80    pub fn with_dict(writer: W, level: i32, dict: Dictionary) -> Result<Self, CompressError> {
81        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
82        let (hash_table, hash_long) = alloc_hash_tables(&params);
83        let dict_hash = vec![0u32; hash_table.len()];
84        let rep_offsets = *dict.rep_offsets();
85        Ok(Self {
86            inner: writer,
87            params,
88            buffer: Vec::new(),
89            rep_offsets,
90            hasher: Xxh64State::new(0),
91            header_written: false,
92            finished: false,
93            workspace: BlockEncodeWorkspace::new(),
94            dict: Some(dict),
95            first_block: true,
96            hash_table,
97            hash_long,
98            dict_hash,
99            sequences: Vec::new(),
100            combined: Vec::new(),
101            block_out: Vec::new(),
102        })
103    }
104
105    /// Flushes remaining data, writes the content checksum, and returns the inner writer.
106    pub fn finish(mut self) -> Result<W, io::Error> {
107        self.finish_frame()?;
108        Ok(self.inner)
109    }
110
111    /// Finishes the current frame and installs `new_writer` for the next one.
112    ///
113    /// Returns the previous writer containing the completed frame. All
114    /// internal buffers (hash tables, workspace, block scratch) stay
115    /// allocated and are reused for the next frame.
116    pub fn reset(&mut self, new_writer: W) -> Result<W, io::Error> {
117        self.finish_frame()?;
118        let old = core::mem::replace(&mut self.inner, new_writer);
119        self.header_written = false;
120        self.finished = false;
121        self.first_block = self.dict.is_some();
122        self.rep_offsets = match &self.dict {
123            Some(d) => *d.rep_offsets(),
124            None => [1, 4, 8],
125        };
126        self.hasher = Xxh64State::new(0);
127        self.workspace.prev_huffman = None;
128        Ok(old)
129    }
130
131    fn finish_frame(&mut self) -> io::Result<()> {
132        if self.finished {
133            return Ok(());
134        }
135        self.finished = true;
136
137        if !self.header_written {
138            self.write_header()?;
139        }
140
141        self.flush_block(true)?;
142
143        let hash = self.hasher.finish();
144        let checksum = (hash & 0xFFFF_FFFF) as u32;
145        self.inner.write_all(&checksum.to_le_bytes())?;
146        Ok(())
147    }
148
149    fn write_header(&mut self) -> io::Result<()> {
150        self.header_written = true;
151
152        self.inner.write_all(&ZSTD_MAGIC.to_le_bytes())?;
153
154        let window_log = self.params.window_log;
155
156        let dict_id_flag = if let Some(ref dict) = self.dict {
157            let id = dict.id();
158            if id <= 0xFF {
159                1u8
160            } else if id <= 0xFFFF {
161                2
162            } else {
163                3
164            }
165        } else {
166            0
167        };
168
169        let descriptor = 0x04u8 | dict_id_flag;
170        self.inner.write_all(&[descriptor])?;
171
172        let mantissa = 0u8;
173        let exponent = (window_log - 10) as u8;
174        let window_descriptor = (exponent << 3) | mantissa;
175        self.inner.write_all(&[window_descriptor])?;
176
177        if let Some(ref dict) = self.dict {
178            let id = dict.id();
179            match dict_id_flag {
180                1 => self.inner.write_all(&[id as u8])?,
181                2 => self.inner.write_all(&(id as u16).to_le_bytes())?,
182                3 => self.inner.write_all(&id.to_le_bytes())?,
183                _ => unreachable!(),
184            }
185        }
186
187        Ok(())
188    }
189
190    fn flush_block(&mut self, last: bool) -> io::Result<()> {
191        if self.buffer.is_empty() && last {
192            self.block_out.clear();
193            block_encoder::encode_raw_block(&[], true, &mut self.block_out);
194            self.inner.write_all(&self.block_out)?;
195            return Ok(());
196        }
197
198        if self.buffer.is_empty() {
199            return Ok(());
200        }
201
202        let chunk = core::mem::take(&mut self.buffer);
203
204        self.block_out.clear();
205        self.block_out.reserve(chunk.len() + 32);
206        if crate::block_looks_incompressible(&chunk) {
207            block_encoder::encode_raw_block(&chunk, last, &mut self.block_out);
208        } else {
209            let use_prefix = self.first_block && self.dict.is_some();
210            if use_prefix {
211                let prefix = self.dict.as_ref().unwrap().content();
212                match self.params.strategy {
213                    Strategy::Fast => {
214                        fast::compress_fast_with_prefix_reuse(
215                            &chunk,
216                            &self.params,
217                            &self.rep_offsets,
218                            prefix,
219                            &mut self.dict_hash,
220                            &mut self.hash_table,
221                            &mut self.sequences,
222                            &mut self.combined,
223                        );
224                    }
225                    Strategy::DFast => {
226                        dfast::compress_dfast_with_prefix_reuse(
227                            &chunk,
228                            &self.params,
229                            &self.rep_offsets,
230                            prefix,
231                            &mut self.hash_table,
232                            &mut self.hash_long,
233                            &mut self.sequences,
234                            &mut self.combined,
235                        );
236                    }
237                }
238            } else {
239                self.hash_table.fill(0);
240                if !self.hash_long.is_empty() {
241                    self.hash_long.fill(0);
242                }
243                match self.params.strategy {
244                    Strategy::Fast => {
245                        fast::compress_fast_block(
246                            &chunk,
247                            0,
248                            chunk.len(),
249                            &self.params,
250                            &self.rep_offsets,
251                            &mut self.hash_table,
252                            &mut self.sequences,
253                        );
254                    }
255                    Strategy::DFast => {
256                        dfast::compress_dfast_block(
257                            &chunk,
258                            0,
259                            chunk.len(),
260                            &self.params,
261                            &self.rep_offsets,
262                            &mut self.hash_table,
263                            &mut self.hash_long,
264                            &mut self.sequences,
265                        );
266                    }
267                }
268            }
269            if self.params.force_raw_literals {
270                block_encoder::encode_compressed_block_raw(
271                    &chunk,
272                    &self.sequences,
273                    &mut self.rep_offsets,
274                    last,
275                    &mut self.block_out,
276                    &mut self.workspace,
277                );
278            } else {
279                block_encoder::encode_compressed_block(
280                    &chunk,
281                    &self.sequences,
282                    &mut self.rep_offsets,
283                    last,
284                    &mut self.block_out,
285                    &mut self.workspace,
286                );
287            }
288        }
289
290        self.first_block = false;
291        self.inner.write_all(&self.block_out)?;
292        Ok(())
293    }
294}
295
296fn alloc_hash_tables(params: &LevelParams) -> (Vec<u32>, Vec<u32>) {
297    match params.strategy {
298        Strategy::Fast => (vec![0u32; 1usize << params.hash_log], Vec::new()),
299        Strategy::DFast => (
300            vec![0u32; 1usize << params.chain_log],
301            vec![0u32; 1usize << params.hash_log],
302        ),
303    }
304}
305
306impl<W: Write> Write for FrameEncoder<W> {
307    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
308        if self.finished {
309            return Err(io::Error::other("encoder already finished"));
310        }
311
312        if !self.header_written {
313            self.write_header()?;
314        }
315
316        self.hasher.update(buf);
317
318        let mut consumed = 0;
319        while consumed < buf.len() {
320            let space = MAX_BLOCK_SIZE - self.buffer.len();
321            let n = space.min(buf.len() - consumed);
322            self.buffer.extend_from_slice(&buf[consumed..consumed + n]);
323            consumed += n;
324
325            if self.buffer.len() >= MAX_BLOCK_SIZE {
326                self.flush_block(false)?;
327            }
328        }
329
330        Ok(consumed)
331    }
332
333    fn flush(&mut self) -> io::Result<()> {
334        if !self.buffer.is_empty() {
335            self.flush_block(false)?;
336        }
337        self.inner.flush()
338    }
339}