Skip to main content

rust_zstd/
fse.rs

//! FSE (Finite State Entropy) encoder.
//! Ported from the zstd C sources `lib/common/fse.h` and `lib/compress/fse_compress.c`;
//! the sequences-section encoder follows `lib/compress/zstd_compress_sequences.c`.
3
4use super::bitstream::BackwardBitWriter;
5
/// Per-symbol compression transform (port of C `FSE_symbolCompressionTransform`).
///
/// For an encoder state `s`, encoding this symbol emits `nb_bits = (s + delta_nb_bits) >> 16`
/// bits and moves to `state_table[(s >> nb_bits) + delta_find_state]`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SymbolTT {
    /// Offset added to `state >> nb_bits` to index this symbol's block of the state table.
    pub delta_find_state: i32,
    /// Bias added to the state so that bits 16 and up of the sum give the bit count to emit.
    pub delta_nb_bits: u32,
}
12
13/// Compiled FSE compression table.
14pub struct FseCTable {
15    pub table_log: u32,
16    pub state_table: Vec<u16>,
17    pub symbol_tt: Vec<SymbolTT>,
18    pub max_symbol: usize,
19}
20
21impl FseCTable {
22    /// Build an FSE compression table from normalized counts.
23    /// Ported from FSE_buildCTable_wksp() in fse_compress.c.
24    pub fn build(norm: &[i16], max_symbol: usize, table_log: u32) -> Self {
25        let table_size = 1u32 << table_log;
26        let table_mask = table_size - 1;
27
28        // 1. Build cumulative counts and place low-probability symbols.
29        let mut cumul = vec![0u16; max_symbol + 2];
30        let mut high_threshold = table_size - 1;
31        let mut table_symbol = vec![0u8; table_size as usize];
32
33        for s in 0..=max_symbol {
34            if norm[s] == -1 {
35                cumul[s + 1] = cumul[s] + 1;
36                table_symbol[high_threshold as usize] = s as u8;
37                high_threshold = high_threshold.wrapping_sub(1);
38            } else {
39                cumul[s + 1] = cumul[s] + norm[s] as u16;
40            }
41        }
42        cumul[max_symbol + 1] = (table_size + 1) as u16;
43
44        // 2. Spread symbols into the table using the exact C formula.
45        let step = (table_size >> 1) + (table_size >> 3) + 3;
46        let mut pos = 0u32;
47        for s in 0..=max_symbol {
48            let count = if norm[s] <= 0 { 0 } else { norm[s] as u32 };
49            for _ in 0..count {
50                table_symbol[pos as usize] = s as u8;
51                pos = (pos + step) & table_mask;
52                while pos > high_threshold {
53                    pos = (pos + step) & table_mask;
54                }
55            }
56        }
57        debug_assert_eq!(pos, 0);
58
59        // 3. Build state transition table sorted by symbol order.
60        let mut state_table = vec![0u16; table_size as usize];
61        for u in 0..table_size {
62            let s = table_symbol[u as usize] as usize;
63            let idx = cumul[s] as usize;
64            state_table[idx] = (table_size + u) as u16;
65            cumul[s] += 1;
66        }
67
68        // 4. Build per-symbol compression transforms.
69        // Use the decoder's baseline/numbits calculation for compatibility.
70        // For each state in the table, compute its decoder-compatible numbits,
71        // then derive the CTable's delta_nb_bits and delta_find_state from that.
72        let mut symbol_tt = vec![SymbolTT::default(); max_symbol + 1];
73        let _sym_count_tt = vec![0u32; max_symbol + 1];
74        let mut total = 0u32;
75        for s in 0..=max_symbol {
76            let prob = if norm[s] == -1 {
77                1
78            } else {
79                norm[s].max(0) as u32
80            };
81            if prob == 0 {
82                symbol_tt[s].delta_nb_bits = ((table_log + 1) << 16) - table_size;
83            } else if prob == 1 {
84                symbol_tt[s].delta_nb_bits = (table_log << 16) - table_size;
85                symbol_tt[s].delta_find_state = total as i32 - 1;
86            } else {
87                // Use the same formula as C zstd FSE_buildCTable
88                let max_bits_out = table_log - highest_bit(prob - 1);
89                let min_state_plus = prob << max_bits_out;
90                symbol_tt[s].delta_nb_bits = (max_bits_out << 16).wrapping_sub(min_state_plus);
91                symbol_tt[s].delta_find_state = total as i32 - prob as i32;
92            }
93            total += prob;
94        }
95
96        Self {
97            table_log,
98            state_table,
99            symbol_tt,
100            max_symbol,
101        }
102    }
103
104    /// Build an RLE compression table: single symbol, 0 bits per encode.
105    /// init_state returns 0, encode_symbol always returns (0, 0, 0).
106    /// table_log = 0, matching the decoder's RLE behavior.
107    pub fn build_rle(symbol: u8) -> Self {
108        let s = symbol as usize;
109        let max_symbol = s;
110        // Handcraft a table where everything resolves to 0 bits, state=0.
111        let state_table = vec![0u16; 1]; // state_table[0] = 0
112        let mut symbol_tt = vec![SymbolTT::default(); max_symbol + 1];
113        // We need encode_symbol to return (0, 0, 0):
114        // nb_bits = (state + delta_nb_bits) >> 16
115        //   For state=0: nb_bits = delta_nb_bits >> 16 = 0 (if delta_nb_bits < 65536)
116        // bits_out = state & ((1 << 0) - 1) = state & 0 = 0
117        // new_state = state_table[(state >> 0) + delta_find_state]
118        //   state >> 0 = 0, delta_find_state = 0 → state_table[0] = 0
119        symbol_tt[s] = SymbolTT {
120            delta_find_state: 0,
121            delta_nb_bits: 0,
122        };
123        Self {
124            table_log: 0,
125            state_table,
126            symbol_tt,
127            max_symbol,
128        }
129    }
130
131    /// Initialize FSE state for the first symbol (FSE_initCState2).
132    pub fn init_state(&self, symbol: usize) -> u32 {
133        let stt = &self.symbol_tt[symbol];
134        let nb_bits = ((stt.delta_nb_bits as u64 + (1 << 15)) >> 16) as u32;
135        let base_val = (nb_bits << 16).wrapping_sub(stt.delta_nb_bits);
136        self.state_table[((base_val >> nb_bits) as i32 + stt.delta_find_state) as usize] as u32
137    }
138
139    /// Encode a symbol: output bits from current state, then transition.
140    /// Returns (bits_to_output, nb_bits, new_state).
141    pub fn encode_symbol(&self, state: u32, symbol: usize) -> (u32, u32, u32) {
142        let stt = &self.symbol_tt[symbol];
143        let nb_bits = (state.wrapping_add(stt.delta_nb_bits)) >> 16;
144        let bits_out = state & ((1 << nb_bits) - 1);
145        let new_state =
146            self.state_table[((state >> nb_bits) as i32 + stt.delta_find_state) as usize] as u32;
147        (bits_out, nb_bits, new_state)
148    }
149}
150
/// Position of the most significant set bit of `v` (floor of log2), or 0 when `v` is 0.
///
/// Matches C `ZSTD_highbit32` for non-zero input.
fn highest_bit(v: u32) -> u32 {
    match v {
        0 => 0,
        nonzero => (u32::BITS - 1) - nonzero.leading_zeros(),
    }
}
157
158#[allow(clippy::too_many_arguments)]
159/// Encode sequences using predefined FSE tables.
160/// Exact port of ZSTD_encodeSequences_body from zstd_compress_sequences.c.
161pub fn encode_sequences(
162    ll_table: &FseCTable,
163    off_table: &FseCTable,
164    ml_table: &FseCTable,
165    ll_codes: &[u8],
166    off_codes: &[u8],
167    ml_codes: &[u8],
168    ll_values: &[u32],  // literal length values (for extra bits)
169    ml_values: &[u32],  // match length - MINMATCH values (for extra bits)
170    off_values: &[u32], // offset values (for extra bits)
171) -> Vec<u8> {
172    use super::constants::*;
173
174    let nb_seq = ll_codes.len();
175    if nb_seq == 0 {
176        return vec![];
177    }
178
179    let mut bw = BackwardBitWriter::new();
180
181    // Initialize states from the last sequence (first in encoding order)
182    let last = nb_seq - 1;
183    let mut state_ll = ll_table.init_state(ll_codes[last] as usize);
184    let mut state_off = off_table.init_state(off_codes[last] as usize);
185    let mut state_ml = ml_table.init_state(ml_codes[last] as usize);
186
187    // Encode extra bits for the last sequence
188    let ll_bits_n = LL_BITS[ll_codes[last] as usize] as u32;
189    bw.add_bits(ll_values[last] as u64, ll_bits_n);
190    if ll_bits_n > 0 {
191        bw.flush_bits();
192    }
193
194    let ml_bits_n = ML_BITS[ml_codes[last] as usize] as u32;
195    bw.add_bits(ml_values[last] as u64, ml_bits_n);
196    if ml_bits_n > 0 {
197        bw.flush_bits();
198    }
199
200    let of_bits_n = off_codes[last] as u32;
201    bw.add_bits(off_values[last] as u64, of_bits_n);
202    bw.flush_bits();
203
204    // Encode remaining sequences in reverse order
205    if nb_seq >= 2 {
206        for n in (0..last).rev() {
207            let llc = ll_codes[n] as usize;
208            let ofc = off_codes[n] as usize;
209            let mlc = ml_codes[n] as usize;
210
211            // FSE encode: OFF, ML, LL (order matters!)
212            let (bits, nb, new_state) = off_table.encode_symbol(state_off, ofc);
213            bw.add_bits(bits as u64, nb);
214            state_off = new_state;
215
216            let (bits, nb, new_state) = ml_table.encode_symbol(state_ml, mlc);
217            bw.add_bits(bits as u64, nb);
218            state_ml = new_state;
219
220            let (bits, nb, new_state) = ll_table.encode_symbol(state_ll, llc);
221            bw.add_bits(bits as u64, nb);
222            state_ll = new_state;
223
224            bw.flush_bits();
225
226            // Extra bits: LL, ML, OFF
227            let ll_eb = LL_BITS[llc] as u32;
228            bw.add_bits(ll_values[n] as u64, ll_eb);
229
230            let ml_eb = ML_BITS[mlc] as u32;
231            bw.add_bits(ml_values[n] as u64, ml_eb);
232
233            let of_eb = ofc as u32;
234            bw.add_bits(off_values[n] as u64, of_eb);
235            bw.flush_bits();
236        }
237    }
238
239    // Flush final states
240    bw.add_bits(state_ml as u64, ml_table.table_log);
241    bw.flush_bits();
242    bw.add_bits(state_off as u64, off_table.table_log);
243    bw.flush_bits();
244    bw.add_bits(state_ll as u64, ll_table.table_log);
245    bw.flush_bits();
246
247    bw.finish()
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::*;

    /// Check that the state table holds every encoder state in `[table_size, 2 * table_size)`
    /// exactly once.
    fn assert_state_table_is_permutation(table: &FseCTable) {
        let table_size = 1u32 << table.table_log;
        let mut states: Vec<u32> = table.state_table.iter().map(|&s| u32::from(s)).collect();
        states.sort_unstable();
        assert!(states.iter().copied().eq(table_size..2 * table_size));
    }

    /// Drive `table` through every symbol with a non-zero count. Check that each step emits
    /// at most `table_log` bits and lands on a valid encoder state.
    fn assert_encoding_stays_in_range(norm: &[i16], table: &FseCTable) {
        let table_size = 1u32 << table.table_log;
        let states = table_size..2 * table_size;
        let first = norm
            .iter()
            .position(|&n| n != 0)
            .expect("table has no symbols");
        let mut state = table.init_state(first);
        assert!(states.contains(&state));
        for symbol in (0..=table.max_symbol).filter(|&s| norm[s] != 0) {
            let (bits, nb_bits, next) = table.encode_symbol(state, symbol);
            assert!(
                nb_bits <= table.table_log,
                "symbol {} emitted {} bits",
                symbol,
                nb_bits
            );
            assert_eq!(u64::from(bits) >> nb_bits, 0);
            assert!(states.contains(&next), "symbol {} led to state {}", symbol, next);
            state = next;
        }
    }

    #[test]
    fn build_ll_default_table() {
        let table = FseCTable::build(&LL_DEFAULT_NORM, MAX_LL, LL_DEFAULT_NORM_LOG);
        assert_eq!(table.table_log, 6);
        assert_eq!(table.state_table.len(), 64);
    }

    #[test]
    fn build_ml_default_table() {
        let table = FseCTable::build(&ML_DEFAULT_NORM, MAX_ML, ML_DEFAULT_NORM_LOG);
        assert_eq!(table.table_log, 6);
        assert_eq!(table.state_table.len(), 64);
    }

    #[test]
    fn state_tables_are_permutations() {
        assert_state_table_is_permutation(&FseCTable::build(
            &LL_DEFAULT_NORM,
            MAX_LL,
            LL_DEFAULT_NORM_LOG,
        ));
        assert_state_table_is_permutation(&FseCTable::build(
            &ML_DEFAULT_NORM,
            MAX_ML,
            ML_DEFAULT_NORM_LOG,
        ));
    }

    #[test]
    fn init_state_in_range() {
        let table = FseCTable::build(&LL_DEFAULT_NORM, MAX_LL, LL_DEFAULT_NORM_LOG);
        // Encoder states are stored in the [table_size, 2 * table_size) range.
        let state = table.init_state(0);
        let table_size = 1u32 << table.table_log;
        assert!((table_size..(table_size * 2)).contains(&state));
    }

    #[test]
    fn encode_symbol_stays_in_range() {
        let ll = FseCTable::build(&LL_DEFAULT_NORM, MAX_LL, LL_DEFAULT_NORM_LOG);
        assert_encoding_stays_in_range(&LL_DEFAULT_NORM, &ll);
        let ml = FseCTable::build(&ML_DEFAULT_NORM, MAX_ML, ML_DEFAULT_NORM_LOG);
        assert_encoding_stays_in_range(&ML_DEFAULT_NORM, &ml);
    }

    #[test]
    fn rle_table_emits_no_bits() {
        let table = FseCTable::build_rle(5);
        assert_eq!(table.table_log, 0);
        assert_eq!(table.max_symbol, 5);
        let state = table.init_state(5);
        assert_eq!(state, 0);
        assert_eq!(table.encode_symbol(state, 5), (0, 0, 0));
    }

    #[test]
    fn highest_bit_is_floor_log2() {
        assert_eq!(highest_bit(0), 0);
        assert_eq!(highest_bit(1), 0);
        assert_eq!(highest_bit(7), 2);
        assert_eq!(highest_bit(8), 3);
        assert_eq!(highest_bit(u32::MAX), 31);
    }
}