1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use crate::bitstream::writer::BitWriter;
9
10pub struct FseEncodeTable {
11 pub symbol_tt: Vec<SymbolTransform>,
12 pub state_table: Vec<u16>,
13 pub accuracy_log: u8,
14}
15
/// Per-symbol encoding parameters derived from a normalized distribution.
///
/// For a current encoder state `s`, the encoder emits
/// `nb_bits = (s + delta_nb_bits) >> 16` low bits of `s`, then looks up the
/// next state at index `(s >> nb_bits) + delta_find_state` of the table's
/// `state_table`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SymbolTransform {
    /// Biased bit count: the high 16 bits of `state + delta_nb_bits` give the
    /// number of bits to emit for this symbol.
    pub delta_nb_bits: u32,
    /// Offset added to `state >> nb_bits` to index the state table.
    pub delta_find_state: i32,
}
21
#[cfg(feature = "alloc")]
impl FseEncodeTable {
    /// Builds an FSE encoding table from a normalized symbol distribution.
    ///
    /// `distribution[s]` is the normalized count of symbol `s`:
    ///
    /// * `p > 0` — the symbol owns `p` slots of the state table;
    /// * `-1` — a "less than one" low-probability symbol; it owns exactly one
    ///   slot, taken from the top of the spread table, and is encoded exactly
    ///   like a count-1 symbol;
    /// * `0` (or any other negative value) — the symbol owns no slot and must
    ///   not be encoded.
    ///
    /// The table has `1 << accuracy_log` slots. Callers are expected to pass a
    /// distribution whose counts (each `-1` counted as 1) sum to exactly that
    /// size; other inputs yield a table that does not encode correctly.
    ///
    /// # Panics
    ///
    /// Panics if `distribution` is empty, and may panic (out-of-bounds index)
    /// on distributions that over-fill the table. In debug builds, also panics
    /// if `accuracy_log > 15`, since states would no longer fit in `u16`.
    pub fn from_distribution(distribution: &[i16], accuracy_log: u8) -> Self {
        // Stored states are `table_size + slot` < 2 * table_size and must fit in u16.
        debug_assert!(accuracy_log <= 15, "accuracy_log too large for u16 states");

        let table_size = 1usize << accuracy_log;
        let mask = table_size - 1;
        // Same step as the reference FSE construction. For table sizes >= 16 it
        // is odd, so repeated stepping visits every slot once per cycle.
        let step = (table_size >> 1) + (table_size >> 3) + 3;

        // Pass 1: cumulative slot counts (symbol-sorted order), and placement of
        // low-probability (-1) symbols at the top of the spread table, growing down.
        let mut spread = vec![0u16; table_size];
        let mut cumul = vec![0u32; distribution.len() + 1];
        let mut high_threshold = table_size - 1;
        for (s, &prob) in distribution.iter().enumerate() {
            let slots = if prob == -1 {
                spread[high_threshold] = s as u16;
                // When -1 symbols fill the whole table, slot 0 is legitimately
                // used; wrapping (rather than a debug-build underflow panic) is
                // harmless because no positive-count symbol remains to spread.
                high_threshold = high_threshold.wrapping_sub(1);
                1
            } else {
                prob.max(0) as u32
            };
            cumul[s + 1] = cumul[s] + slots;
        }

        // Pass 2: scatter every positive-count symbol over the slots below the
        // low-probability area.
        let mut position = 0usize;
        for (s, &prob) in distribution.iter().enumerate() {
            for _ in 0..prob.max(0) {
                spread[position] = s as u16;
                position = (position + step) & mask;
                while position > high_threshold {
                    position = (position + step) & mask;
                }
            }
        }

        // Pass 3: per-symbol transforms giving the encoder its bit count and
        // next-state lookup offset.
        let al = u32::from(accuracy_log);
        let symbol_tt: Vec<SymbolTransform> = distribution
            .iter()
            .zip(&cumul)
            .map(|(&prob, &start)| match prob {
                // A -1 symbol owns a single slot, exactly like a count-1 symbol,
                // so it needs the same transform to reach that slot.
                -1 | 1 => SymbolTransform {
                    delta_nb_bits: (al << 16).wrapping_sub(1 << al),
                    delta_find_state: start as i32 - 1,
                },
                p if p > 1 => {
                    let p = p as u32;
                    let max_bits_out = al - high_bit(p - 1);
                    let min_state_plus = p << max_bits_out;
                    SymbolTransform {
                        delta_nb_bits: (max_bits_out << 16).wrapping_sub(min_state_plus),
                        delta_find_state: start as i32 - p as i32,
                    }
                }
                // Absent symbol: owns no slot; the entry is filler only.
                _ => SymbolTransform {
                    delta_nb_bits: (al + 1) << 16,
                    delta_find_state: 0,
                },
            })
            .collect();

        // Pass 4: list each symbol's states in spread order. The encoder indexes
        // this with `(state >> nb_bits) + delta_find_state`.
        let mut next_slot = cumul[..distribution.len()].to_vec();
        let mut state_table = vec![0u16; table_size];
        for (u, &sym) in spread.iter().enumerate() {
            let slot = &mut next_slot[usize::from(sym)];
            state_table[*slot as usize] = (table_size + u) as u16;
            *slot += 1;
        }

        Self {
            symbol_tt,
            state_table,
            accuracy_log,
        }
    }
}
107
108pub struct FseEncodeState<'t> {
109 table: &'t FseEncodeTable,
110 state: u32,
111}
112
#[cfg(feature = "alloc")]
impl<'t> FseEncodeState<'t> {
    /// Creates an encoder state over `table`, derived from `symbol`'s
    /// transform. No bits are written.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` (or the derived index) is out of range for `table`.
    pub fn init(table: &'t FseEncodeTable, symbol: u8) -> Self {
        let tt = table.symbol_tt[symbol as usize];
        // Bit count implied by the transform, rounded up.
        let nb_bits = tt.delta_nb_bits.wrapping_add(1 << 16) >> 16;
        let base = (nb_bits << 16).wrapping_sub(tt.delta_nb_bits);
        let index = (base >> nb_bits) as i32 + tt.delta_find_state;
        Self {
            state: u32::from(table.state_table[index as usize]),
            table,
        }
    }

    /// Encodes `symbol`: writes the low bits of the current state to `writer`
    /// and moves to the next state.
    ///
    /// The number of bits written is `(state + delta_nb_bits) >> 16` for the
    /// symbol's transform.
    pub fn encode_symbol(&mut self, writer: &mut BitWriter, symbol: u8) {
        let tt = self.table.symbol_tt[symbol as usize];
        let current = self.state;
        let nb_bits = (current.wrapping_add(tt.delta_nb_bits) >> 16) as u8;
        let low_bits = current & ((1u32 << nb_bits) - 1);
        writer.write_bits(low_bits, nb_bits);
        let index = (current >> nb_bits) as i32 + tt.delta_find_state;
        self.state = u32::from(self.table.state_table[index as usize]);
    }

    /// Writes the low `accuracy_log` bits of the current state to `writer`.
    ///
    /// NOTE(review): callers presumably pass the table's own `accuracy_log`;
    /// this is not checked here.
    pub fn flush(&self, writer: &mut BitWriter, accuracy_log: u8) {
        let low_bits = self.state & ((1u32 << accuracy_log) - 1);
        writer.write_bits(low_bits, accuracy_log);
    }

    /// Current raw state value.
    pub fn state(&self) -> u32 {
        self.state
    }
}
141
/// Returns the position of the most significant set bit of `val`, i.e.
/// `floor(log2(val))`.
///
/// `val` must be non-zero; this is only checked in debug builds.
fn high_bit(val: u32) -> u32 {
    debug_assert!(val > 0);
    let leading = val.leading_zeros();
    u32::BITS - 1 - leading
}