Skip to main content

zrip_encode/
context.rs

1#[cfg(feature = "alloc")]
2use alloc::borrow::Cow;
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use crate::block_encoder::{self, BlockEncodeWorkspace};
9use crate::strategy::{self, LevelParams, Strategy};
10use crate::{block_looks_incompressible, clamp_params_to_src_size, dfast, fast};
11use zrip_core::Sequence;
12use zrip_core::dict::Dictionary;
13use zrip_core::error::CompressError;
14use zrip_core::frame::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
15use zrip_core::xxhash::xxh64;
16
17/// Reusable compression context that amortizes hash table and buffer allocations.
18///
19/// Holds internal state (hash tables, output buffer, block encoder workspace)
20/// across calls. Useful when compressing many small inputs in a loop.
21///
22/// ```
23/// let mut ctx = zrip::CompressContext::new(1).unwrap();
24/// for i in 0..10 {
25///     let data = format!("message {i}").repeat(100);
26///     let compressed = ctx.compress(data.as_bytes()).unwrap();
27///     assert!(compressed.len() < data.len());
28/// }
29/// ```
30pub struct CompressContext {
31    params: LevelParams,
32    dict: Option<Dictionary>,
33    hash_table: Vec<u32>,
34    hash_long: Vec<u32>,
35    sequences: Vec<Sequence>,
36    output: Vec<u8>,
37    workspace: BlockEncodeWorkspace,
38    combined: Vec<u8>,
39}
40
41impl CompressContext {
42    /// Creates a new context for the given compression level (-7..=4).
43    pub fn new(level: i32) -> Result<Self, CompressError> {
44        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
45        let (hash_table, hash_long) = match params.strategy {
46            Strategy::Fast => (vec![0u32; 1usize << params.hash_log], Vec::new()),
47            Strategy::DFast => (
48                vec![0u32; 1usize << params.chain_log],
49                vec![0u32; 1usize << params.hash_log],
50            ),
51        };
52        Ok(Self {
53            params,
54            dict: None,
55            hash_table,
56            hash_long,
57            sequences: Vec::new(),
58            output: Vec::new(),
59            workspace: BlockEncodeWorkspace::new(),
60            combined: Vec::new(),
61        })
62    }
63
64    /// Creates a new context with a pre-loaded dictionary.
65    pub fn with_dict(level: i32, dict: Dictionary) -> Result<Self, CompressError> {
66        let mut ctx = Self::new(level)?;
67        ctx.dict = Some(dict);
68        Ok(ctx)
69    }
70
71    /// Compresses `input` using the context's level and optional dictionary.
72    pub fn compress(&mut self, input: &[u8]) -> Result<Cow<'_, [u8]>, CompressError> {
73        let (dict_id, prefix, init_rep) = if let Some(ref d) = self.dict {
74            (Some(d.id()), d.content(), *d.rep_offsets())
75        } else {
76            (None, &[] as &[u8], [1u32, 4, 8])
77        };
78        compress_core(
79            input,
80            self.params,
81            dict_id,
82            prefix,
83            init_rep,
84            &mut self.hash_table,
85            &mut self.hash_long,
86            &mut self.sequences,
87            &mut self.output,
88            &mut self.workspace,
89            &mut self.combined,
90        )?;
91        Ok(self.take_or_borrow_output())
92    }
93
94    /// Compresses `input` using an ad-hoc dictionary (overrides the stored one).
95    pub fn compress_with_dict(
96        &mut self,
97        input: &[u8],
98        dict: &Dictionary,
99    ) -> Result<Cow<'_, [u8]>, CompressError> {
100        compress_core(
101            input,
102            self.params,
103            Some(dict.id()),
104            dict.content(),
105            *dict.rep_offsets(),
106            &mut self.hash_table,
107            &mut self.hash_long,
108            &mut self.sequences,
109            &mut self.output,
110            &mut self.workspace,
111            &mut self.combined,
112        )?;
113        Ok(self.take_or_borrow_output())
114    }
115
116    fn take_or_borrow_output(&mut self) -> Cow<'_, [u8]> {
117        if self.output.len() >= zrip_core::LARGE_OUTPUT_THRESHOLD {
118            Cow::Owned(core::mem::take(&mut self.output))
119        } else {
120            Cow::Borrowed(&self.output)
121        }
122    }
123}
124
125fn compress_core(
126    input: &[u8],
127    params: LevelParams,
128    dict_id: Option<u32>,
129    prefix: &[u8],
130    init_rep_offsets: [u32; 3],
131    hash_table: &mut Vec<u32>,
132    hash_long: &mut Vec<u32>,
133    sequences: &mut Vec<Sequence>,
134    output: &mut Vec<u8>,
135    workspace: &mut BlockEncodeWorkspace,
136    combined: &mut Vec<u8>,
137) -> Result<(), CompressError> {
138    let mut params = params;
139    clamp_params_to_src_size(&mut params, input.len());
140    let hash_size = match params.strategy {
141        Strategy::Fast => 1usize << params.hash_log,
142        Strategy::DFast => 1usize << params.chain_log,
143    };
144    let long_size = 1usize << params.hash_log;
145
146    workspace.prev_huffman = None;
147
148    output.clear();
149    output.reserve(input.len() + 32);
150    output.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
151
152    let fcs_size = if input.len() <= 255 {
153        1
154    } else if input.len() <= 0xFFFF + 256 {
155        2
156    } else if input.len() <= 0xFFFFFFFF {
157        4
158    } else {
159        8
160    };
161    let fcs_flag = match fcs_size {
162        1 => 0,
163        2 => 1,
164        4 => 2,
165        8 => 3,
166        _ => unreachable!(),
167    };
168
169    if let Some(did) = dict_id {
170        let dict_id_flag = if did <= 0xFF {
171            1u8
172        } else if did <= 0xFFFF {
173            2
174        } else {
175            3
176        };
177        let descriptor = 0x20 | 0x04 | (fcs_flag << 6) | dict_id_flag;
178        output.push(descriptor);
179        match dict_id_flag {
180            1 => output.push(did as u8),
181            2 => output.extend_from_slice(&(did as u16).to_le_bytes()),
182            3 => output.extend_from_slice(&did.to_le_bytes()),
183            _ => unreachable!(),
184        }
185    } else {
186        let descriptor = 0x20 | 0x04 | (fcs_flag << 6);
187        output.push(descriptor);
188    }
189
190    match fcs_size {
191        1 => output.push(input.len() as u8),
192        2 => {
193            let v = (input.len() - 256) as u16;
194            output.extend_from_slice(&v.to_le_bytes());
195        }
196        4 => output.extend_from_slice(&(input.len() as u32).to_le_bytes()),
197        8 => output.extend_from_slice(&(input.len() as u64).to_le_bytes()),
198        _ => unreachable!(),
199    }
200
201    if input.is_empty() {
202        block_encoder::encode_raw_block(&[], true, output);
203    } else {
204        let has_prefix = !prefix.is_empty();
205        let mut rep_offsets = init_rep_offsets;
206        let mut offset = 0;
207
208        if hash_table.len() != hash_size {
209            hash_table.resize(hash_size, 0);
210        }
211
212        match params.strategy {
213            Strategy::Fast => {
214                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
215                    fast::compress_fast_with_prefix_reuse(
216                        input,
217                        &params,
218                        &rep_offsets,
219                        prefix,
220                        hash_table,
221                        sequences,
222                        combined,
223                    );
224                    if params.force_raw_literals {
225                        block_encoder::encode_compressed_block_raw(
226                            input,
227                            sequences,
228                            &mut rep_offsets,
229                            true,
230                            output,
231                            workspace,
232                        );
233                    } else {
234                        block_encoder::encode_compressed_block(
235                            input,
236                            sequences,
237                            &mut rep_offsets,
238                            true,
239                            output,
240                            workspace,
241                        );
242                    }
243                } else if has_prefix {
244                    combined.clear();
245                    combined.reserve(prefix.len() + input.len());
246                    combined.extend_from_slice(prefix);
247                    combined.extend_from_slice(input);
248                    let plen = prefix.len();
249                    fast::prefill_hash_table(combined, plen, params.hash_log, hash_table);
250
251                    while offset < input.len() {
252                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
253                        let is_last = offset + chunk_size >= input.len();
254                        fast::compress_fast_block(
255                            combined,
256                            plen + offset,
257                            plen + offset + chunk_size,
258                            &params,
259                            &rep_offsets,
260                            hash_table,
261                            sequences,
262                        );
263                        if params.force_raw_literals {
264                            block_encoder::encode_compressed_block_raw(
265                                &input[offset..offset + chunk_size],
266                                sequences,
267                                &mut rep_offsets,
268                                is_last,
269                                output,
270                                workspace,
271                            );
272                        } else {
273                            block_encoder::encode_compressed_block(
274                                &input[offset..offset + chunk_size],
275                                sequences,
276                                &mut rep_offsets,
277                                is_last,
278                                output,
279                                workspace,
280                            );
281                        }
282                        offset += chunk_size;
283                    }
284                } else {
285                    hash_table.fill(0);
286                    while offset < input.len() {
287                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
288                        let block_end = offset + chunk_size;
289                        let is_last = block_end >= input.len();
290
291                        if block_looks_incompressible(&input[offset..block_end]) {
292                            block_encoder::encode_raw_block(
293                                &input[offset..block_end],
294                                is_last,
295                                output,
296                            );
297                        } else {
298                            fast::compress_fast_block(
299                                input,
300                                offset,
301                                block_end,
302                                &params,
303                                &rep_offsets,
304                                hash_table,
305                                sequences,
306                            );
307                            if params.force_raw_literals {
308                                block_encoder::encode_compressed_block_raw(
309                                    &input[offset..block_end],
310                                    sequences,
311                                    &mut rep_offsets,
312                                    is_last,
313                                    output,
314                                    workspace,
315                                );
316                            } else {
317                                block_encoder::encode_compressed_block(
318                                    &input[offset..block_end],
319                                    sequences,
320                                    &mut rep_offsets,
321                                    is_last,
322                                    output,
323                                    workspace,
324                                );
325                            }
326                        }
327                        offset = block_end;
328                    }
329                }
330            }
331            Strategy::DFast => {
332                if hash_long.len() != long_size {
333                    hash_long.resize(long_size, 0);
334                }
335                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
336                    dfast::compress_dfast_with_prefix_reuse(
337                        input,
338                        &params,
339                        &rep_offsets,
340                        prefix,
341                        hash_table,
342                        hash_long,
343                        sequences,
344                        combined,
345                    );
346                    block_encoder::encode_compressed_block(
347                        input,
348                        sequences,
349                        &mut rep_offsets,
350                        true,
351                        output,
352                        workspace,
353                    );
354                } else if has_prefix {
355                    combined.clear();
356                    combined.reserve(prefix.len() + input.len());
357                    combined.extend_from_slice(prefix);
358                    combined.extend_from_slice(input);
359                    let plen = prefix.len();
360                    dfast::prefill_hash_tables(
361                        combined,
362                        plen,
363                        params.hash_log,
364                        params.chain_log,
365                        hash_table,
366                        hash_long,
367                    );
368
369                    while offset < input.len() {
370                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
371                        let is_last = offset + chunk_size >= input.len();
372                        dfast::compress_dfast_block(
373                            combined,
374                            plen + offset,
375                            plen + offset + chunk_size,
376                            &params,
377                            &rep_offsets,
378                            hash_table,
379                            hash_long,
380                            sequences,
381                        );
382                        block_encoder::encode_compressed_block(
383                            &input[offset..offset + chunk_size],
384                            sequences,
385                            &mut rep_offsets,
386                            is_last,
387                            output,
388                            workspace,
389                        );
390                        offset += chunk_size;
391                    }
392                } else {
393                    hash_table.fill(0);
394                    hash_long.fill(0);
395                    while offset < input.len() {
396                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
397                        let block_end = offset + chunk_size;
398                        let is_last = block_end >= input.len();
399
400                        if block_looks_incompressible(&input[offset..block_end]) {
401                            block_encoder::encode_raw_block(
402                                &input[offset..block_end],
403                                is_last,
404                                output,
405                            );
406                        } else {
407                            dfast::compress_dfast_block(
408                                input,
409                                offset,
410                                block_end,
411                                &params,
412                                &rep_offsets,
413                                hash_table,
414                                hash_long,
415                                sequences,
416                            );
417                            block_encoder::encode_compressed_block(
418                                &input[offset..block_end],
419                                sequences,
420                                &mut rep_offsets,
421                                is_last,
422                                output,
423                                workspace,
424                            );
425                        }
426                        offset = block_end;
427                    }
428                }
429            }
430        }
431    }
432
433    let hash = xxh64(input, 0);
434    let checksum = (hash & 0xFFFFFFFF) as u32;
435    output.extend_from_slice(&checksum.to_le_bytes());
436
437    Ok(())
438}