1#![forbid(unsafe_code)]
2
3use std::io::{self, Write};
4
5use crate::block_encoder::{self, BlockEncodeWorkspace};
6use crate::dfast;
7use crate::fast;
8use crate::strategy::{self, LevelParams, Strategy};
9use zrip_core::Sequence;
10use zrip_core::dict::Dictionary;
11use zrip_core::error::CompressError;
12use zrip_core::frame::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use zrip_core::xxhash::Xxh64State;
14
15pub struct FrameEncoder<W: Write> {
36 inner: W,
37 params: LevelParams,
38 buffer: Vec<u8>,
39 rep_offsets: [u32; 3],
40 hasher: Xxh64State,
41 header_written: bool,
42 finished: bool,
43 workspace: BlockEncodeWorkspace,
44 dict: Option<Dictionary>,
45 first_block: bool,
46 hash_table: Vec<u32>,
47 hash_long: Vec<u32>,
48 dict_hash: Vec<u32>,
49 sequences: Vec<Sequence>,
50 combined: Vec<u8>,
51 block_out: Vec<u8>,
52}
53
54impl<W: Write> FrameEncoder<W> {
55 pub fn new(writer: W, level: i32) -> Result<Self, CompressError> {
57 let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
58 let (hash_table, hash_long) = alloc_hash_tables(¶ms);
59 Ok(Self {
60 inner: writer,
61 params,
62 buffer: Vec::new(),
63 rep_offsets: [1, 4, 8],
64 hasher: Xxh64State::new(0),
65 header_written: false,
66 finished: false,
67 workspace: BlockEncodeWorkspace::new(),
68 dict: None,
69 first_block: false,
70 hash_table,
71 hash_long,
72 dict_hash: Vec::new(),
73 sequences: Vec::new(),
74 combined: Vec::new(),
75 block_out: Vec::new(),
76 })
77 }
78
79 pub fn with_dict(writer: W, level: i32, dict: Dictionary) -> Result<Self, CompressError> {
81 let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
82 let (hash_table, hash_long) = alloc_hash_tables(¶ms);
83 let dict_hash = vec![0u32; hash_table.len()];
84 let rep_offsets = *dict.rep_offsets();
85 Ok(Self {
86 inner: writer,
87 params,
88 buffer: Vec::new(),
89 rep_offsets,
90 hasher: Xxh64State::new(0),
91 header_written: false,
92 finished: false,
93 workspace: BlockEncodeWorkspace::new(),
94 dict: Some(dict),
95 first_block: true,
96 hash_table,
97 hash_long,
98 dict_hash,
99 sequences: Vec::new(),
100 combined: Vec::new(),
101 block_out: Vec::new(),
102 })
103 }
104
105 pub fn finish(mut self) -> Result<W, io::Error> {
107 self.finish_frame()?;
108 Ok(self.inner)
109 }
110
111 pub fn reset(&mut self, new_writer: W) -> Result<W, io::Error> {
117 self.finish_frame()?;
118 let old = core::mem::replace(&mut self.inner, new_writer);
119 self.header_written = false;
120 self.finished = false;
121 self.first_block = self.dict.is_some();
122 self.rep_offsets = match &self.dict {
123 Some(d) => *d.rep_offsets(),
124 None => [1, 4, 8],
125 };
126 self.hasher = Xxh64State::new(0);
127 self.workspace.prev_huffman = None;
128 Ok(old)
129 }
130
131 fn finish_frame(&mut self) -> io::Result<()> {
132 if self.finished {
133 return Ok(());
134 }
135 self.finished = true;
136
137 if !self.header_written {
138 self.write_header()?;
139 }
140
141 self.flush_block(true)?;
142
143 let hash = self.hasher.finish();
144 let checksum = (hash & 0xFFFF_FFFF) as u32;
145 self.inner.write_all(&checksum.to_le_bytes())?;
146 Ok(())
147 }
148
149 fn write_header(&mut self) -> io::Result<()> {
150 self.header_written = true;
151
152 self.inner.write_all(&ZSTD_MAGIC.to_le_bytes())?;
153
154 let window_log = self.params.window_log;
155
156 let dict_id_flag = if let Some(ref dict) = self.dict {
157 let id = dict.id();
158 if id <= 0xFF {
159 1u8
160 } else if id <= 0xFFFF {
161 2
162 } else {
163 3
164 }
165 } else {
166 0
167 };
168
169 let descriptor = 0x04u8 | dict_id_flag;
170 self.inner.write_all(&[descriptor])?;
171
172 let mantissa = 0u8;
173 let exponent = (window_log - 10) as u8;
174 let window_descriptor = (exponent << 3) | mantissa;
175 self.inner.write_all(&[window_descriptor])?;
176
177 if let Some(ref dict) = self.dict {
178 let id = dict.id();
179 match dict_id_flag {
180 1 => self.inner.write_all(&[id as u8])?,
181 2 => self.inner.write_all(&(id as u16).to_le_bytes())?,
182 3 => self.inner.write_all(&id.to_le_bytes())?,
183 _ => unreachable!(),
184 }
185 }
186
187 Ok(())
188 }
189
190 fn flush_block(&mut self, last: bool) -> io::Result<()> {
191 if self.buffer.is_empty() && last {
192 self.block_out.clear();
193 block_encoder::encode_raw_block(&[], true, &mut self.block_out);
194 self.inner.write_all(&self.block_out)?;
195 return Ok(());
196 }
197
198 if self.buffer.is_empty() {
199 return Ok(());
200 }
201
202 let chunk = core::mem::take(&mut self.buffer);
203
204 self.block_out.clear();
205 self.block_out.reserve(chunk.len() + 32);
206 if crate::block_looks_incompressible(&chunk) {
207 block_encoder::encode_raw_block(&chunk, last, &mut self.block_out);
208 } else {
209 let use_prefix = self.first_block && self.dict.is_some();
210 if use_prefix {
211 let prefix = self.dict.as_ref().unwrap().content();
212 match self.params.strategy {
213 Strategy::Fast => {
214 fast::compress_fast_with_prefix_reuse(
215 &chunk,
216 &self.params,
217 &self.rep_offsets,
218 prefix,
219 &mut self.dict_hash,
220 &mut self.hash_table,
221 &mut self.sequences,
222 &mut self.combined,
223 );
224 }
225 Strategy::DFast => {
226 dfast::compress_dfast_with_prefix_reuse(
227 &chunk,
228 &self.params,
229 &self.rep_offsets,
230 prefix,
231 &mut self.hash_table,
232 &mut self.hash_long,
233 &mut self.sequences,
234 &mut self.combined,
235 );
236 }
237 }
238 } else {
239 self.hash_table.fill(0);
240 if !self.hash_long.is_empty() {
241 self.hash_long.fill(0);
242 }
243 match self.params.strategy {
244 Strategy::Fast => {
245 fast::compress_fast_block(
246 &chunk,
247 0,
248 chunk.len(),
249 &self.params,
250 &self.rep_offsets,
251 &mut self.hash_table,
252 &mut self.sequences,
253 );
254 }
255 Strategy::DFast => {
256 dfast::compress_dfast_block(
257 &chunk,
258 0,
259 chunk.len(),
260 &self.params,
261 &self.rep_offsets,
262 &mut self.hash_table,
263 &mut self.hash_long,
264 &mut self.sequences,
265 );
266 }
267 }
268 }
269 if self.params.force_raw_literals {
270 block_encoder::encode_compressed_block_raw(
271 &chunk,
272 &self.sequences,
273 &mut self.rep_offsets,
274 last,
275 &mut self.block_out,
276 &mut self.workspace,
277 );
278 } else {
279 block_encoder::encode_compressed_block(
280 &chunk,
281 &self.sequences,
282 &mut self.rep_offsets,
283 last,
284 &mut self.block_out,
285 &mut self.workspace,
286 );
287 }
288 }
289
290 self.first_block = false;
291 self.inner.write_all(&self.block_out)?;
292 Ok(())
293 }
294}
295
296fn alloc_hash_tables(params: &LevelParams) -> (Vec<u32>, Vec<u32>) {
297 match params.strategy {
298 Strategy::Fast => (vec![0u32; 1usize << params.hash_log], Vec::new()),
299 Strategy::DFast => (
300 vec![0u32; 1usize << params.chain_log],
301 vec![0u32; 1usize << params.hash_log],
302 ),
303 }
304}
305
306impl<W: Write> Write for FrameEncoder<W> {
307 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
308 if self.finished {
309 return Err(io::Error::other("encoder already finished"));
310 }
311
312 if !self.header_written {
313 self.write_header()?;
314 }
315
316 self.hasher.update(buf);
317
318 let mut consumed = 0;
319 while consumed < buf.len() {
320 let space = MAX_BLOCK_SIZE - self.buffer.len();
321 let n = space.min(buf.len() - consumed);
322 self.buffer.extend_from_slice(&buf[consumed..consumed + n]);
323 consumed += n;
324
325 if self.buffer.len() >= MAX_BLOCK_SIZE {
326 self.flush_block(false)?;
327 }
328 }
329
330 Ok(consumed)
331 }
332
333 fn flush(&mut self) -> io::Result<()> {
334 if !self.buffer.is_empty() {
335 self.flush_block(false)?;
336 }
337 self.inner.flush()
338 }
339}