Skip to main content

zrip_encode/
context.rs

1#[cfg(feature = "alloc")]
2use alloc::borrow::Cow;
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use crate::block_encoder::{self, BlockEncodeWorkspace};
9use crate::strategy::{self, LevelParams, Strategy};
10use crate::{block_looks_incompressible, clamp_params_to_src_size, dfast, fast};
11use zrip_core::Sequence;
12use zrip_core::dict::Dictionary;
13use zrip_core::error::CompressError;
14use zrip_core::frame::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
15use zrip_core::xxhash::xxh64;
16
17/// Reusable compression context that amortizes hash table and buffer allocations.
18///
19/// Holds internal state (hash tables, output buffer, block encoder workspace)
20/// across calls. Useful when compressing many small inputs in a loop.
21///
22/// ```
23/// let mut ctx = zrip::CompressContext::new(1).unwrap();
24/// for i in 0..10 {
25///     let data = format!("message {i}").repeat(100);
26///     let compressed = ctx.compress(data.as_bytes()).unwrap();
27///     assert!(compressed.len() < data.len());
28/// }
29/// ```
30pub struct CompressContext {
31    params: LevelParams,
32    dict: Option<Dictionary>,
33    hash_table: Vec<u32>,
34    hash_long: Vec<u32>,
35    sequences: Vec<Sequence>,
36    output: Vec<u8>,
37    workspace: BlockEncodeWorkspace,
38    combined: Vec<u8>,
39}
40
41impl CompressContext {
42    /// Creates a new context for the given compression level (-7..=4).
43    pub fn new(level: i32) -> Result<Self, CompressError> {
44        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
45        let (hash_table, hash_long) = match params.strategy {
46            Strategy::Fast => (vec![0u32; 1usize << params.hash_log], Vec::new()),
47            Strategy::DFast => (
48                vec![0u32; 1usize << params.chain_log],
49                vec![0u32; 1usize << params.hash_log],
50            ),
51        };
52        Ok(Self {
53            params,
54            dict: None,
55            hash_table,
56            hash_long,
57            sequences: Vec::new(),
58            output: Vec::new(),
59            workspace: BlockEncodeWorkspace::new(),
60            combined: Vec::new(),
61        })
62    }
63
64    /// Creates a new context with a pre-loaded dictionary.
65    pub fn with_dict(level: i32, dict: Dictionary) -> Result<Self, CompressError> {
66        let mut ctx = Self::new(level)?;
67        ctx.dict = Some(dict);
68        Ok(ctx)
69    }
70
71    /// Compresses `input` using the context's level and optional dictionary.
72    pub fn compress(&mut self, input: &[u8]) -> Result<Cow<'_, [u8]>, CompressError> {
73        let (dict_id, prefix, init_rep) = if let Some(ref d) = self.dict {
74            (Some(d.id()), d.content(), *d.rep_offsets())
75        } else {
76            (None, &[] as &[u8], [1u32, 4, 8])
77        };
78        compress_core(
79            input,
80            self.params,
81            dict_id,
82            prefix,
83            init_rep,
84            &mut self.hash_table,
85            &mut self.hash_long,
86            &mut self.sequences,
87            &mut self.output,
88            &mut self.workspace,
89            &mut self.combined,
90        )?;
91        Ok(self.take_or_borrow_output())
92    }
93
94    /// Compresses `input` using an ad-hoc dictionary (overrides the stored one).
95    pub fn compress_with_dict(
96        &mut self,
97        input: &[u8],
98        dict: &Dictionary,
99    ) -> Result<Cow<'_, [u8]>, CompressError> {
100        compress_core(
101            input,
102            self.params,
103            Some(dict.id()),
104            dict.content(),
105            *dict.rep_offsets(),
106            &mut self.hash_table,
107            &mut self.hash_long,
108            &mut self.sequences,
109            &mut self.output,
110            &mut self.workspace,
111            &mut self.combined,
112        )?;
113        Ok(self.take_or_borrow_output())
114    }
115
116    fn take_or_borrow_output(&mut self) -> Cow<'_, [u8]> {
117        if self.output.len() >= zrip_core::LARGE_OUTPUT_THRESHOLD {
118            Cow::Owned(core::mem::take(&mut self.output))
119        } else {
120            Cow::Borrowed(&self.output)
121        }
122    }
123}
124
125#[allow(clippy::too_many_arguments, clippy::unnecessary_wraps)]
126fn compress_core(
127    input: &[u8],
128    params: LevelParams,
129    dict_id: Option<u32>,
130    prefix: &[u8],
131    init_rep_offsets: [u32; 3],
132    hash_table: &mut Vec<u32>,
133    hash_long: &mut Vec<u32>,
134    sequences: &mut Vec<Sequence>,
135    output: &mut Vec<u8>,
136    workspace: &mut BlockEncodeWorkspace,
137    combined: &mut Vec<u8>,
138) -> Result<(), CompressError> {
139    let mut params = params;
140    clamp_params_to_src_size(&mut params, input.len());
141    let hash_size = match params.strategy {
142        Strategy::Fast => 1usize << params.hash_log,
143        Strategy::DFast => 1usize << params.chain_log,
144    };
145    let long_size = 1usize << params.hash_log;
146
147    workspace.prev_huffman = None;
148
149    output.clear();
150    output.reserve(input.len() + 32);
151    output.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
152
153    let fcs_size = if input.len() <= 255 {
154        1
155    } else if input.len() <= 0xFFFF + 256 {
156        2
157    } else if input.len() <= 0xFFFF_FFFF {
158        4
159    } else {
160        8
161    };
162    let fcs_flag = match fcs_size {
163        1 => 0,
164        2 => 1,
165        4 => 2,
166        8 => 3,
167        _ => unreachable!(),
168    };
169
170    if let Some(did) = dict_id {
171        let dict_id_flag = if did <= 0xFF {
172            1u8
173        } else if did <= 0xFFFF {
174            2
175        } else {
176            3
177        };
178        let descriptor = 0x20 | 0x04 | (fcs_flag << 6) | dict_id_flag;
179        output.push(descriptor);
180        match dict_id_flag {
181            1 => output.push(did as u8),
182            2 => output.extend_from_slice(&(did as u16).to_le_bytes()),
183            3 => output.extend_from_slice(&did.to_le_bytes()),
184            _ => unreachable!(),
185        }
186    } else {
187        let descriptor = 0x20 | 0x04 | (fcs_flag << 6);
188        output.push(descriptor);
189    }
190
191    match fcs_size {
192        1 => output.push(input.len() as u8),
193        2 => {
194            let v = (input.len() - 256) as u16;
195            output.extend_from_slice(&v.to_le_bytes());
196        }
197        4 => output.extend_from_slice(&(input.len() as u32).to_le_bytes()),
198        8 => output.extend_from_slice(&(input.len() as u64).to_le_bytes()),
199        _ => unreachable!(),
200    }
201
202    if input.is_empty() {
203        block_encoder::encode_raw_block(&[], true, output);
204    } else {
205        let has_prefix = !prefix.is_empty();
206        let mut rep_offsets = init_rep_offsets;
207        let mut offset = 0;
208
209        if hash_table.len() != hash_size {
210            hash_table.resize(hash_size, 0);
211        }
212
213        match params.strategy {
214            Strategy::Fast => {
215                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
216                    fast::compress_fast_with_prefix_reuse(
217                        input,
218                        &params,
219                        &rep_offsets,
220                        prefix,
221                        hash_table,
222                        sequences,
223                        combined,
224                    );
225                    if params.force_raw_literals {
226                        block_encoder::encode_compressed_block_raw(
227                            input,
228                            sequences,
229                            &mut rep_offsets,
230                            true,
231                            output,
232                            workspace,
233                        );
234                    } else {
235                        block_encoder::encode_compressed_block(
236                            input,
237                            sequences,
238                            &mut rep_offsets,
239                            true,
240                            output,
241                            workspace,
242                        );
243                    }
244                } else if has_prefix {
245                    combined.clear();
246                    combined.reserve(prefix.len() + input.len());
247                    combined.extend_from_slice(prefix);
248                    combined.extend_from_slice(input);
249                    let plen = prefix.len();
250                    fast::prefill_hash_table(combined, plen, params.hash_log, hash_table);
251
252                    while offset < input.len() {
253                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
254                        let is_last = offset + chunk_size >= input.len();
255                        fast::compress_fast_block(
256                            combined,
257                            plen + offset,
258                            plen + offset + chunk_size,
259                            &params,
260                            &rep_offsets,
261                            hash_table,
262                            sequences,
263                        );
264                        if params.force_raw_literals {
265                            block_encoder::encode_compressed_block_raw(
266                                &input[offset..offset + chunk_size],
267                                sequences,
268                                &mut rep_offsets,
269                                is_last,
270                                output,
271                                workspace,
272                            );
273                        } else {
274                            block_encoder::encode_compressed_block(
275                                &input[offset..offset + chunk_size],
276                                sequences,
277                                &mut rep_offsets,
278                                is_last,
279                                output,
280                                workspace,
281                            );
282                        }
283                        offset += chunk_size;
284                    }
285                } else {
286                    hash_table.fill(0);
287                    while offset < input.len() {
288                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
289                        let block_end = offset + chunk_size;
290                        let is_last = block_end >= input.len();
291
292                        if block_looks_incompressible(&input[offset..block_end]) {
293                            block_encoder::encode_raw_block(
294                                &input[offset..block_end],
295                                is_last,
296                                output,
297                            );
298                        } else {
299                            fast::compress_fast_block(
300                                input,
301                                offset,
302                                block_end,
303                                &params,
304                                &rep_offsets,
305                                hash_table,
306                                sequences,
307                            );
308                            if params.force_raw_literals {
309                                block_encoder::encode_compressed_block_raw(
310                                    &input[offset..block_end],
311                                    sequences,
312                                    &mut rep_offsets,
313                                    is_last,
314                                    output,
315                                    workspace,
316                                );
317                            } else {
318                                block_encoder::encode_compressed_block(
319                                    &input[offset..block_end],
320                                    sequences,
321                                    &mut rep_offsets,
322                                    is_last,
323                                    output,
324                                    workspace,
325                                );
326                            }
327                        }
328                        offset = block_end;
329                    }
330                }
331            }
332            Strategy::DFast => {
333                if hash_long.len() != long_size {
334                    hash_long.resize(long_size, 0);
335                }
336                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
337                    dfast::compress_dfast_with_prefix_reuse(
338                        input,
339                        &params,
340                        &rep_offsets,
341                        prefix,
342                        hash_table,
343                        hash_long,
344                        sequences,
345                        combined,
346                    );
347                    block_encoder::encode_compressed_block(
348                        input,
349                        sequences,
350                        &mut rep_offsets,
351                        true,
352                        output,
353                        workspace,
354                    );
355                } else if has_prefix {
356                    combined.clear();
357                    combined.reserve(prefix.len() + input.len());
358                    combined.extend_from_slice(prefix);
359                    combined.extend_from_slice(input);
360                    let plen = prefix.len();
361                    dfast::prefill_hash_tables(
362                        combined,
363                        plen,
364                        params.hash_log,
365                        params.chain_log,
366                        hash_table,
367                        hash_long,
368                    );
369
370                    while offset < input.len() {
371                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
372                        let is_last = offset + chunk_size >= input.len();
373                        dfast::compress_dfast_block(
374                            combined,
375                            plen + offset,
376                            plen + offset + chunk_size,
377                            &params,
378                            &rep_offsets,
379                            hash_table,
380                            hash_long,
381                            sequences,
382                        );
383                        block_encoder::encode_compressed_block(
384                            &input[offset..offset + chunk_size],
385                            sequences,
386                            &mut rep_offsets,
387                            is_last,
388                            output,
389                            workspace,
390                        );
391                        offset += chunk_size;
392                    }
393                } else {
394                    hash_table.fill(0);
395                    hash_long.fill(0);
396                    while offset < input.len() {
397                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
398                        let block_end = offset + chunk_size;
399                        let is_last = block_end >= input.len();
400
401                        if block_looks_incompressible(&input[offset..block_end]) {
402                            block_encoder::encode_raw_block(
403                                &input[offset..block_end],
404                                is_last,
405                                output,
406                            );
407                        } else {
408                            dfast::compress_dfast_block(
409                                input,
410                                offset,
411                                block_end,
412                                &params,
413                                &rep_offsets,
414                                hash_table,
415                                hash_long,
416                                sequences,
417                            );
418                            block_encoder::encode_compressed_block(
419                                &input[offset..block_end],
420                                sequences,
421                                &mut rep_offsets,
422                                is_last,
423                                output,
424                                workspace,
425                            );
426                        }
427                        offset = block_end;
428                    }
429                }
430            }
431        }
432    }
433
434    let hash = xxh64(input, 0);
435    let checksum = (hash & 0xFFFF_FFFF) as u32;
436    output.extend_from_slice(&checksum.to_le_bytes());
437
438    Ok(())
439}