Skip to main content

zrip_core/fse/
encode.rs

1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use crate::bitstream::writer::BitWriter;
9
/// Encoding-side FSE table built for one symbol distribution.
pub struct FseEncodeTable {
    /// Per-symbol transforms, indexed by symbol value.
    pub symbol_tt: Vec<SymbolTransform>,
    /// Next-state values, laid out so that each symbol's cells are contiguous.
    /// `from_distribution` fills it with `1 << accuracy_log` entries whose
    /// values lie in `[table_size, 2 * table_size)`.
    pub state_table: Vec<u16>,
    /// Base-2 logarithm of the table size.
    pub accuracy_log: u8,
}
15
/// Per-symbol parameters used by the encoder to pick the bit count and the
/// next state for a symbol.
///
/// The type is a small plain-data value, so it also derives `Debug`,
/// `PartialEq` and `Eq` to be printable and comparable in diagnostics and
/// tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SymbolTransform {
    /// Added (wrapping) to the current state; the upper 16 bits of the sum are
    /// the number of bits to emit for the symbol.
    pub delta_nb_bits: u32,
    /// Added to `state >> nb_bits` to obtain the index into the state table.
    pub delta_find_state: i32,
}
21
#[cfg(feature = "alloc")]
impl FseEncodeTable {
    /// Builds the encoding table for a normalized symbol distribution.
    ///
    /// `distribution[s]` is the normalized count of symbol `s`:
    /// * a positive value is the number of table cells given to the symbol,
    /// * `0` (or any value below `-1`) means the symbol owns no cell,
    /// * `-1` marks a "less than one" probability symbol, which owns exactly
    ///   one cell taken from the top of the table.
    ///
    /// `accuracy_log` is the base-2 logarithm of the table size. State values
    /// are stored as `u16` in `[table_size, 2 * table_size)`, so they only fit
    /// without truncation for `accuracy_log <= 15`.
    ///
    /// # Panics
    ///
    /// May panic (index out of bounds or arithmetic overflow) when the counts
    /// do not describe a table of exactly `1 << accuracy_log` cells.
    pub fn from_distribution(distribution: &[i16], accuracy_log: u8) -> Self {
        let table_size = 1usize << accuracy_log;
        let symbol_count = distribution.len();
        let log = u32::from(accuracy_log);

        let step = (table_size >> 1) + (table_size >> 3) + 3;
        let mask = table_size - 1;

        // Symbol owning each table cell, before cells are grouped by symbol.
        let mut cell_symbol = vec![0u16; table_size];
        // cumul[s] = index of the first cell of symbol `s` in the sorted table.
        let mut cumul = vec![0u32; symbol_count + 1];

        // "Less than one" symbols are placed at the top of the table, from the
        // last cell downwards.
        let mut high_threshold = table_size - 1;
        for (s, &prob) in distribution.iter().enumerate() {
            if prob == -1 {
                cell_symbol[high_threshold] = s as u16;
                // Wrapping: when such symbols fill the whole table the
                // threshold steps below zero, exactly like the unsigned
                // counter of the reference implementation. A plain `-= 1`
                // would panic in debug builds on that input.
                high_threshold = high_threshold.wrapping_sub(1);
                cumul[s + 1] = cumul[s] + 1;
            } else {
                cumul[s + 1] = cumul[s] + prob.max(0) as u32;
            }
        }

        // Spread the remaining symbols over the cells below the threshold.
        let mut position = 0usize;
        for (s, &prob) in distribution.iter().enumerate() {
            if prob <= 0 {
                continue;
            }
            for _ in 0..prob {
                cell_symbol[position] = s as u16;
                position = (position + step) & mask;
                while position > high_threshold {
                    position = (position + step) & mask;
                }
            }
        }

        // Per-symbol transforms. `zip` stops after `symbol_count` items, so
        // the trailing total stored in `cumul` is ignored here.
        let symbol_tt: Vec<SymbolTransform> = distribution
            .iter()
            .zip(&cumul)
            .map(|(&prob, &first_cell)| match prob {
                // A symbol owning exactly one cell always emits `accuracy_log`
                // bits and lands on that cell. A `-1` symbol owns one cell
                // too, so it needs the same transform as a count of 1 (this is
                // what zstd's FSE_buildCTable does); giving it the "absent"
                // transform would emit one bit too many and jump to cell 0.
                -1 | 1 => SymbolTransform {
                    delta_nb_bits: (log << 16).wrapping_sub(1u32 << accuracy_log),
                    delta_find_state: first_cell as i32 - 1,
                },
                count if count > 1 => {
                    let count = count as u32;
                    let max_bits_out = log - high_bit(count - 1);
                    let min_state_plus = count << max_bits_out;
                    SymbolTransform {
                        delta_nb_bits: (max_bits_out << 16).wrapping_sub(min_state_plus),
                        delta_find_state: first_cell as i32 - count as i32,
                    }
                }
                // Symbol without any cell: it must never be encoded; the entry
                // is filled with a placeholder only.
                _ => SymbolTransform {
                    delta_nb_bits: (log + 1) << 16,
                    delta_find_state: 0,
                },
            })
            .collect();

        // State table stores values in [table_size, 2*table_size),
        // matching C zstd's FSE_buildCTable convention. Cells are grouped by
        // symbol: `next_cell[s]` is the next free slot of symbol `s`.
        let next_cell = &mut cumul[..symbol_count];
        let mut table_symbol_sorted = vec![0u16; table_size];
        for (i, &symbol) in cell_symbol.iter().enumerate() {
            let slot = &mut next_cell[symbol as usize];
            table_symbol_sorted[*slot as usize] = (table_size + i) as u16;
            *slot += 1;
        }

        Self {
            symbol_tt,
            state_table: table_symbol_sorted,
            accuracy_log,
        }
    }
}
107
/// Running encoder state bound to a borrowed [`FseEncodeTable`].
pub struct FseEncodeState<'t> {
    /// Table used to look up symbol transforms and next-state values.
    table: &'t FseEncodeTable,
    /// Current state value, as read from the table's `state_table`.
    state: u32,
}
112
#[cfg(feature = "alloc")]
impl<'t> FseEncodeState<'t> {
    /// Creates a state for `table` whose initial value is selected by
    /// `symbol`'s transform. No bits are written.
    ///
    /// Panics if `symbol` has no entry in `table.symbol_tt` or if the computed
    /// index falls outside `table.state_table`.
    pub fn init(table: &'t FseEncodeTable, symbol: u8) -> Self {
        let transform = table.symbol_tt[usize::from(symbol)];
        let bit_count = transform.delta_nb_bits.wrapping_add(1 << 16) >> 16;
        let base = (bit_count << 16).wrapping_sub(transform.delta_nb_bits);
        let slot = (base >> bit_count) as i32 + transform.delta_find_state;
        Self {
            table,
            state: u32::from(table.state_table[slot as usize]),
        }
    }

    /// Encodes `symbol`: writes the low bits of the current state to `writer`,
    /// then moves to the next state read from the table.
    pub fn encode_symbol(&mut self, writer: &mut BitWriter, symbol: u8) {
        let transform = self.table.symbol_tt[usize::from(symbol)];
        let current = self.state;
        // The upper 16 bits of the sum are the number of bits to emit.
        let bit_count = (current.wrapping_add(transform.delta_nb_bits) >> 16) as u8;
        let low_mask = (1u32 << bit_count) - 1;
        writer.write_bits(current & low_mask, bit_count);
        let slot = (current >> bit_count) as i32 + transform.delta_find_state;
        self.state = u32::from(self.table.state_table[slot as usize]);
    }

    /// Writes the low `accuracy_log` bits of the current state to `writer`.
    pub fn flush(&self, writer: &mut BitWriter, accuracy_log: u8) {
        let low_mask = (1u32 << accuracy_log) - 1;
        writer.write_bits(self.state & low_mask, accuracy_log);
    }

    /// Returns the current state value.
    pub fn state(&self) -> u32 {
        self.state
    }
}
141
/// Returns the zero-based index of the most significant set bit of `val`.
///
/// `val` must be non-zero; this is checked in debug builds only.
fn high_bit(val: u32) -> u32 {
    debug_assert!(val > 0);
    let top_index = u32::BITS - 1;
    top_index - val.leading_zeros()
}