structured_zstd/decoding/
dictionary.rs1#[cfg(not(target_has_atomic = "ptr"))]
2use alloc::rc::Rc;
3#[cfg(target_has_atomic = "ptr")]
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6use core::convert::TryInto;
7
8use crate::decoding::errors::DictionaryDecodeError;
9use crate::decoding::scratch::FSEScratch;
10use crate::decoding::scratch::HuffmanScratch;
11
12pub struct Dictionary {
17 pub id: u32,
20 pub fse: FSEScratch,
23 pub huf: HuffmanScratch,
26 pub dict_content: Vec<u8>,
36 pub offset_hist: [u32; 3],
41}
42
43#[cfg(target_has_atomic = "ptr")]
44type SharedDictionary = Arc<Dictionary>;
45#[cfg(not(target_has_atomic = "ptr"))]
46type SharedDictionary = Rc<Dictionary>;
47
48#[derive(Clone)]
52pub struct DictionaryHandle {
53 inner: SharedDictionary,
54}
55
/// Magic number that must open every serialized dictionary: the value
/// `0xEC30A437` stored in little-endian byte order (`37 A4 30 EC`).
pub const MAGIC_NUM: [u8; 4] = 0xEC30_A437_u32.to_le_bytes();
58
59impl Dictionary {
60 pub fn from_raw_content(
65 id: u32,
66 dict_content: Vec<u8>,
67 ) -> Result<Dictionary, DictionaryDecodeError> {
68 if id == 0 {
69 return Err(DictionaryDecodeError::ZeroDictionaryId);
70 }
71 if dict_content.is_empty() {
72 return Err(DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 1 });
73 }
74
75 Ok(Dictionary {
76 id,
77 fse: FSEScratch::new(),
78 huf: HuffmanScratch::new(),
79 dict_content,
80 offset_hist: [1, 4, 8],
81 })
82 }
83
84 pub fn decode_dict(raw: &[u8]) -> Result<Dictionary, DictionaryDecodeError> {
88 const MIN_MAGIC_AND_ID_LEN: usize = 8;
89 const OFFSET_HISTORY_LEN: usize = 12;
90
91 if raw.len() < MIN_MAGIC_AND_ID_LEN {
92 return Err(DictionaryDecodeError::DictionaryTooSmall {
93 got: raw.len(),
94 need: MIN_MAGIC_AND_ID_LEN,
95 });
96 }
97
98 let mut new_dict = Dictionary {
99 id: 0,
100 fse: FSEScratch::new(),
101 huf: HuffmanScratch::new(),
102 dict_content: Vec::new(),
103 offset_hist: [1, 4, 8],
104 };
105
106 let magic_num: [u8; 4] = raw[..4].try_into().expect("optimized away");
107 if magic_num != MAGIC_NUM {
108 return Err(DictionaryDecodeError::BadMagicNum { got: magic_num });
109 }
110
111 let dict_id = raw[4..8].try_into().expect("optimized away");
112 let dict_id = u32::from_le_bytes(dict_id);
113 if dict_id == 0 {
114 return Err(DictionaryDecodeError::ZeroDictionaryId);
115 }
116 new_dict.id = dict_id;
117
118 let raw_tables = &raw[8..];
119
120 let huf_size = new_dict.huf.table.build_decoder(raw_tables)?;
121 let raw_tables = &raw_tables[huf_size as usize..];
122
123 let of_size = new_dict.fse.offsets.build_decoder(
124 raw_tables,
125 crate::decoding::sequence_section_decoder::OF_MAX_LOG,
126 )?;
127 let raw_tables = &raw_tables[of_size..];
128
129 let ml_size = new_dict.fse.match_lengths.build_decoder(
130 raw_tables,
131 crate::decoding::sequence_section_decoder::ML_MAX_LOG,
132 )?;
133 let raw_tables = &raw_tables[ml_size..];
134
135 let ll_size = new_dict.fse.literal_lengths.build_decoder(
136 raw_tables,
137 crate::decoding::sequence_section_decoder::LL_MAX_LOG,
138 )?;
139 let raw_tables = &raw_tables[ll_size..];
140
141 if raw_tables.len() < OFFSET_HISTORY_LEN {
142 return Err(DictionaryDecodeError::DictionaryTooSmall {
143 got: raw_tables.len(),
144 need: OFFSET_HISTORY_LEN,
145 });
146 }
147
148 let offset1 = raw_tables[0..4].try_into().expect("optimized away");
149 let offset1 = u32::from_le_bytes(offset1);
150
151 let offset2 = raw_tables[4..8].try_into().expect("optimized away");
152 let offset2 = u32::from_le_bytes(offset2);
153
154 let offset3 = raw_tables[8..12].try_into().expect("optimized away");
155 let offset3 = u32::from_le_bytes(offset3);
156
157 if offset1 == 0 {
158 return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 0 });
159 }
160 if offset2 == 0 {
161 return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 1 });
162 }
163 if offset3 == 0 {
164 return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 2 });
165 }
166
167 new_dict.offset_hist[0] = offset1;
168 new_dict.offset_hist[1] = offset2;
169 new_dict.offset_hist[2] = offset3;
170
171 let raw_content = &raw_tables[12..];
172 new_dict.dict_content.extend(raw_content);
173
174 Ok(new_dict)
175 }
176
177 pub fn into_handle(self) -> DictionaryHandle {
179 DictionaryHandle::from_dictionary(self)
180 }
181}
182
183impl DictionaryHandle {
184 pub fn from_dictionary(dict: Dictionary) -> Self {
186 Self {
187 inner: SharedDictionary::new(dict),
188 }
189 }
190
191 pub fn decode_dict(raw: &[u8]) -> Result<Self, DictionaryDecodeError> {
193 Dictionary::decode_dict(raw).map(Self::from_dictionary)
194 }
195
196 pub fn id(&self) -> u32 {
197 self.inner.id
198 }
199
200 pub fn as_dict(&self) -> &Dictionary {
201 &self.inner
202 }
203}
204
205impl AsRef<Dictionary> for DictionaryHandle {
206 fn as_ref(&self) -> &Dictionary {
207 self.as_dict()
208 }
209}
210
211impl From<Dictionary> for DictionaryHandle {
212 fn from(dict: Dictionary) -> Self {
213 DictionaryHandle::from_dictionary(dict)
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Returns the index in `raw` at which the repeat-offset history starts,
    /// found by decoding the Huffman and FSE tables that precede it.
    fn offset_history_start(raw: &[u8]) -> usize {
        use crate::decoding::sequence_section_decoder::{LL_MAX_LOG, ML_MAX_LOG, OF_MAX_LOG};

        let mut huf = crate::decoding::scratch::HuffmanScratch::new();
        let mut fse = crate::decoding::scratch::FSEScratch::new();

        // The tables start right after the 4-byte magic number and 4-byte id.
        let huf_start = 8usize;
        let huf_len = huf
            .table
            .build_decoder(&raw[huf_start..])
            .expect("reference dictionary huffman table should decode");

        let of_start = huf_start + huf_len as usize;
        let of_len = fse
            .offsets
            .build_decoder(&raw[of_start..], OF_MAX_LOG)
            .expect("reference dictionary OF table should decode");

        let ml_start = of_start + of_len;
        let ml_len = fse
            .match_lengths
            .build_decoder(&raw[ml_start..], ML_MAX_LOG)
            .expect("reference dictionary ML table should decode");

        let ll_start = ml_start + ml_len;
        let ll_len = fse
            .literal_lengths
            .build_decoder(&raw[ll_start..], LL_MAX_LOG)
            .expect("reference dictionary LL table should decode");

        ll_start + ll_len
    }

    #[test]
    fn decode_dict_rejects_short_buffer_before_magic_and_id() {
        let err = Dictionary::decode_dict(&[])
            .err()
            .unwrap_or_else(|| panic!("expected short dictionary to fail"));

        assert!(matches!(
            err,
            DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 8 }
        ));
    }

    #[test]
    fn decode_dict_malformed_input_returns_error_instead_of_panicking() {
        // Valid magic number and id, followed by 7 zero bytes of "tables".
        let raw = [&MAGIC_NUM[..], &1u32.to_le_bytes()[..], &[0u8; 7][..]].concat();

        match std::panic::catch_unwind(|| Dictionary::decode_dict(&raw)) {
            Err(_) => panic!("decode_dict must not panic on malformed input"),
            Ok(decoded) => assert!(decoded.is_err(), "malformed dictionary must return error"),
        }
    }

    #[test]
    fn decode_dict_rejects_zero_repeat_offsets() {
        let mut raw = include_bytes!("../../dict_tests/dictionary").to_vec();
        let first_offset = offset_history_start(&raw);

        // Overwrite the first repeat offset with zero.
        raw[first_offset..first_offset + 4].copy_from_slice(&[0u8; 4]);

        assert!(matches!(
            Dictionary::decode_dict(&raw),
            Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 0 })
        ));
    }

    #[test]
    fn from_raw_content_rejects_empty_dictionary_content() {
        assert!(matches!(
            Dictionary::from_raw_content(1, Vec::new()),
            Err(DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 1 })
        ));
    }

    #[test]
    fn dictionary_handle_from_raw_content_supports_as_ref() {
        let handle = Dictionary::from_raw_content(7, vec![42])
            .expect("raw dict should build")
            .into_handle();
        let dict_ref: &Dictionary = handle.as_ref();

        assert_eq!(dict_ref.id, 7);
        assert_eq!(dict_ref.dict_content.as_slice(), &[42]);
    }

    #[test]
    fn dictionary_handle_clones_share_inner() {
        let raw = include_bytes!("../../dict_tests/dictionary");
        let original = DictionaryHandle::decode_dict(raw).expect("dictionary should parse");
        let duplicate = original.clone();

        assert_eq!(original.id(), duplicate.id());
        assert!(SharedDictionary::ptr_eq(&original.inner, &duplicate.inner));
    }
}