zstd_framed/
table.rs

1pub mod futures;
2pub mod tokio;
3
4/// A table containing offsets and sizes for the frames within a zstd stream,
5/// such as from the [zstd seekable format].
6///
7/// ## Reading
8///
9/// If a zstd stream uses the [zstd seekable format], you can parse its
10/// seek table using the [`read_seek_table`] function (or one of its async
11/// variants).
12///
13/// ## Usage
14///
15/// [`ZstdReader`](crate::ZstdReader) can use a seek table to speed up
16/// seeks through the stream. To do so, pass the stream's seek table using
17/// the [`.with_seek_table()`](crate::reader::ZstdReaderBuilder::with_seek_table)
18/// builder option. When using [`AsyncZstdReader`](crate::AsyncZstdReader),
19/// you need to use both the `.with_seek_table()` builder option and the
20/// [`.seekable()`](crate::AsyncZstdReader::seekable) wrapper method.
21///
22/// ## Writing
23///
24/// [`ZstdWriter`](crate::ZstdWriter) can write a seek table by enabling the
25/// [`.with_seek_table()`](crate::writer::ZstdWriterBuilder::with_seek_table)
26/// builder option. The same applies when using [`AsyncZstdWriter`](crate::AsyncZstdWriter).
27///
28/// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
29#[derive(Debug)]
30pub struct ZstdSeekTable {
31    frames: Vec<ZstdFrame>,
32}
33
34impl ZstdSeekTable {
35    pub(crate) fn empty() -> Self {
36        Self { frames: vec![] }
37    }
38
39    /// Returns the total number of zstd frames in the table.
40    pub fn num_frames(&self) -> usize {
41        self.frames.len()
42    }
43
44    /// Returns an iterator over each of the frames in the table. Frames
45    /// are ordered from the start of the zstd stream to the end.
46    pub fn frames(&self) -> impl Iterator<Item = ZstdFrame> + '_ {
47        self.frames.iter().copied()
48    }
49
50    /// Returns the first zstd frame in the table, or `None` if the table
51    /// is empty.
52    pub(crate) fn first_frame(&self) -> Option<ZstdFrame> {
53        self.frames.first().copied()
54    }
55
56    /// Returns the last zstd frame in the table, or `None` if the table
57    /// is empty.
58    pub(crate) fn last_frame(&self) -> Option<ZstdFrame> {
59        self.frames.last().copied()
60    }
61
62    pub(crate) fn find_by_decompressed_pos(&self, pos: u64) -> Option<ZstdFrame> {
63        let index = self
64            .frames
65            .binary_search_by(|frame| {
66                if pos < frame.decompressed_pos {
67                    std::cmp::Ordering::Greater
68                } else if pos >= frame.decompressed_pos + frame.size.decompressed_size {
69                    std::cmp::Ordering::Less
70                } else {
71                    std::cmp::Ordering::Equal
72                }
73            })
74            .ok()?;
75        let frame = self.frames[index];
76        Some(frame)
77    }
78
79    pub(crate) fn get(&self, index: usize) -> Option<ZstdFrame> {
80        self.frames.get(index).copied()
81    }
82
83    pub(crate) fn insert(&mut self, frame: ZstdFrame) {
84        let next_index = self.frames.len();
85
86        assert!(next_index >= frame.index);
87
88        if frame.index == next_index {
89            self.frames.push(frame);
90        } else if frame.index + 1 == next_index {
91            self.frames[frame.index] = frame;
92        }
93    }
94}
95
96/// Represents a single frame within a zstd stream, including the compressed
97/// and decompressed offsets within the stream, and the compressed and
98/// decompressed sizes of the frame.
99#[derive(Debug, Clone, Copy)]
100pub struct ZstdFrame {
101    pub(crate) index: usize,
102    pub(crate) compressed_pos: u64,
103    pub(crate) decompressed_pos: u64,
104    pub(crate) size: ZstdFrameSize,
105}
106
107impl ZstdFrame {
108    pub(crate) fn compressed_end(&self) -> u64 {
109        self.compressed_pos + self.size.compressed_size
110    }
111
112    pub(crate) fn decompressed_end(&self) -> u64 {
113        self.decompressed_pos + self.size.decompressed_size
114    }
115
116    /// Get the compressed size of the frame, measured in bytes.
117    pub fn compressed_size(&self) -> u64 {
118        self.size.compressed_size
119    }
120
121    /// Get the size of the frame if it were decompressed, measured in bytes.
122    pub fn decompressed_size(&self) -> u64 {
123        self.size.decompressed_size
124    }
125
126    /// Get the range of positions that cover the range of this frame
127    /// within the compressed zstd stream.
128    pub fn compressed_range(&self) -> std::ops::Range<u64> {
129        self.compressed_pos..self.compressed_end()
130    }
131
132    /// Get the range of positions that this frame would include if the
133    /// zstd stream were decompressed.
134    pub fn decompressed_range(&self) -> std::ops::Range<u64> {
135        self.decompressed_pos..self.decompressed_end()
136    }
137}
138
/// Compressed and decompressed byte counts for a single zstd frame.
///
/// The derived `Default` sets both counts to zero.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct ZstdFrameSize {
    /// Number of bytes the frame occupies in the compressed stream.
    pub(crate) compressed_size: u64,
    /// Number of bytes the frame produces when decompressed.
    pub(crate) decompressed_size: u64,
}
144
145impl ZstdFrameSize {
146    pub(crate) fn add_sizes(&mut self, compressed_size: usize, decompressed_size: usize) {
147        let compressed_written: u64 = compressed_size
148            .try_into()
149            .expect("failed to convert written bytes to u64");
150        let decompressed_written: u64 = decompressed_size
151            .try_into()
152            .expect("failed to convert written bytes to u64");
153
154        let compressed_size = self
155            .compressed_size
156            .checked_add(compressed_written)
157            .expect("adding to compressed size overflowed");
158        let decompressed_size = self
159            .decompressed_size
160            .checked_add(decompressed_written)
161            .expect("adding to decompressed size overflowed");
162
163        self.compressed_size = compressed_size;
164        self.decompressed_size = decompressed_size;
165    }
166}
167
168/// Read the seek table from the end of a [zstd seekable format] stream.
169///
170/// Returns `Ok(None)` if the stream doesn't apper to contain a seek table.
171/// Otherwise, returns `Err(_)` if the seek table could not be parsed or
172/// if an I/O error occurred while trying to read the seek table. If it
173/// returns `Ok(_)`, it will also restore the reader to its original
174/// stream position.
175///
176/// The seek table is returned as-is from the underlying reader. No attempt
177/// is made to validate that the seek table lines up with the underlying
178/// zstd stream. This means a malformed seek table could have out-of-bounds
179/// offsets, could omit sections of the underyling stream, or could be
180/// misaligned from frames of the underlying stream.
181///
182/// Async implementations:
183///
184/// - `tokio`: [`crate::table::tokio::read_seek_table`]
185/// - `futures`: [`crate::table::futures::read_seek_table`]
186///
187/// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
188pub fn read_seek_table<R>(reader: &mut R) -> std::io::Result<Option<ZstdSeekTable>>
189where
190    R: std::io::Read + std::io::Seek,
191{
192    // Get the stream position, so we can restore it later
193    let initial_position = reader.stream_position()?;
194
195    // Read the seek table
196    let seek_table_result = read_seek_table_inner(reader);
197
198    // Try to restore the seek position, even if reading
199    // the seek table failed
200    let seek_result = reader.seek(std::io::SeekFrom::Start(initial_position));
201
202    // If we got an error, return whichever we got first
203    let seek_table = seek_table_result?;
204    seek_result?;
205
206    Ok(seek_table)
207}
208
209fn read_seek_table_inner<R>(reader: &mut R) -> std::io::Result<Option<ZstdSeekTable>>
210where
211    R: std::io::Read + std::io::Seek,
212{
213    // Seek to the start of the zstd seek table footer
214    reader.seek(std::io::SeekFrom::End(-9))?;
215
216    // Read the footer fields: number of frames (4 bytes),
217    // table descriptor (1 byte), and the magic number (4 bytes)
218    let mut num_frames_bytes = [0; 4];
219    reader.read_exact(&mut num_frames_bytes)?;
220
221    let mut seek_table_descriptor_bytes = [0; 1];
222    reader.read_exact(&mut seek_table_descriptor_bytes)?;
223
224    let mut seekable_magic_number_bytes = [0; 4];
225    reader.read_exact(&mut seekable_magic_number_bytes)?;
226
227    // Return if the magic number doesn't match
228    if seekable_magic_number_bytes != crate::SEEKABLE_FOOTER_MAGIC_BYTES {
229        return Ok(None);
230    }
231
232    // Parse the number of frames
233    let num_frames = u32::from_le_bytes(num_frames_bytes);
234
235    // Validate the seek table descriptor
236    let [seek_table_descriptor] = seek_table_descriptor_bytes;
237    let has_checksum = seek_table_descriptor & 0b1000_0000 != 0;
238    let is_reserved_valid = seek_table_descriptor & 0b0111_1100 == 0;
239
240    if !is_reserved_valid {
241        return Err(std::io::Error::other(
242            "zstd seek table has unsupported descriptor",
243        ));
244    }
245
246    // Determine the table entry size (8 bytes, or 12 bytes with checksums)
247    let table_entry_size: u32 = if has_checksum { 12 } else { 8 };
248
249    // Calculate the full size of the skippable frame containing the
250    // seek table. This can't overflow for a valid seek table, since the
251    // frame size is part of the frame header.
252    let table_frame_size = table_entry_size
253        .checked_mul(num_frames)
254        .and_then(|size| size.checked_add(9))
255        .ok_or_else(|| std::io::Error::other("zstd seek table size overflowed"))?;
256
257    // Seek to the start of the skippable frame containing the seek table
258    reader.seek_relative(-i64::from(table_frame_size) - 8)?;
259
260    // Read the skippable frame magic number header: the
261    // magic number (4 bytes) and the frame size (4 bytes)
262    let mut skippable_magic_number_bytes = [0; 4];
263    reader.read_exact(&mut skippable_magic_number_bytes)?;
264
265    let mut actual_table_frame_size_bytes = [0; 4];
266    reader.read_exact(&mut actual_table_frame_size_bytes)?;
267
268    // Validate the skippable frame magic number and frame size
269    if skippable_magic_number_bytes != crate::SKIPPABLE_HEADER_MAGIC_BYTES {
270        return Err(std::io::Error::other(
271            "zstd seek table has unsupported skippable frame magic number",
272        ));
273    }
274
275    let actual_table_frame_size = u32::from_le_bytes(actual_table_frame_size_bytes);
276    if actual_table_frame_size != table_frame_size {
277        return Err(std::io::Error::other("zstd seek table size did not match"));
278    }
279
280    // Read each table entry
281    let mut table = ZstdSeekTable::empty();
282    let mut compressed_pos = 0;
283    let mut decompressed_pos = 0;
284    for frame_index in 0..num_frames {
285        let frame_index = usize::try_from(frame_index).unwrap();
286
287        // Read the compressed size
288        let mut compressed_size_bytes = [0; 4];
289        reader.read_exact(&mut compressed_size_bytes)?;
290        let compressed_size = u32::from_le_bytes(compressed_size_bytes);
291
292        // Read the decompressed size
293        let mut decompressed_size_bytes = [0; 4];
294        reader.read_exact(&mut decompressed_size_bytes)?;
295        let decompressed_size = u32::from_le_bytes(decompressed_size_bytes);
296
297        // Skip the checksum if present
298        if has_checksum {
299            reader.seek_relative(4)?;
300        }
301
302        let frame = ZstdFrame {
303            compressed_pos,
304            decompressed_pos,
305            index: frame_index,
306            size: ZstdFrameSize {
307                compressed_size: compressed_size.into(),
308                decompressed_size: decompressed_size.into(),
309            },
310        };
311        table.insert(frame);
312
313        compressed_pos += u64::from(compressed_size);
314        decompressed_pos += u64::from(decompressed_size);
315    }
316
317    Ok(Some(table))
318}