zstd_framed/
reader.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
use std::io::BufRead as _;

use crate::{buffer::Buffer, decoder::ZstdFramedDecoder, table::ZstdSeekTable};

/// A reader that decompresses a zstd stream from an underlying reader.
///
/// The underyling reader `R` should implement the following traits:
///
/// - [`std::io::BufRead`] (required for [`std::io::Read`] and [`std::io::BufRead`] impls)
/// - (Optional) [`std::io::Seek`] (required for [`std::io::Seek`] impl)
///
/// For async support, see [`crate::AsyncZstdReader`].
///
/// ## Construction
///
/// Create a builder using either [`ZstdReader::builder`] (recommended) or
/// [`ZstdReader::builder_buffered`] (to use a custom buffer).
/// See [`ZstdReaderBuilder`] for build options. Call
/// [`ZstdReaderBuilder::build`] to build the [`ZstdReader`] instance.
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let compressed_file: &[u8] = &[];
/// let reader = zstd_framed::ZstdReader::builder(compressed_file)
///     // .with_seek_table(table) // Provide a seek table if available
///     .build()?;
/// # Ok(())
/// # }
/// ```
///
/// ## Buffering
///
/// The decompressed zstd output is always buffered internally. Since the
/// reader must also implement [`std::io::BufRead`], the compressed input
/// must also be buffered.
///
/// [`ZstdReader::builder`] will wrap any reader implmenting [`std::io::Read`]
/// with a recommended buffer size for the input stream. For more control
/// over how the input gets buffered, you can instead use
/// [`ZstdReader::builder_buffered`].
///
/// ## Seeking
///
/// [`ZstdReader`] implements [`std::io::Seek`] as long as the underlying
/// reader implements [`std::io::Seek`]. **By default, seeking within the
/// stream will linearly decompress until reaching the target!**
///
/// Seeking can do a lot better when the underlying stream is broken up
/// into multiple frames, such as a stream that uses the [zstd seekable format].
/// You can create such a stream using [`ZstdWriterBuilder::with_seek_table`](crate::writer::ZstdWriterBuilder::with_seek_table).
///
/// There are two situations where seeking can take advantage of a seek
/// table:
///
/// 1. When a seek table is provided up-front using [`ZstdReaderBuilder::with_seek_table`].
///    See [`crate::table::read_seek_table`] for reading a seek table
///    from a reader.
/// 2. When rewinding to a previously-decompressed frame. Frame offsets are
///    automatically recorded during decompression.
///
/// Even if a seek table is used, seeking will still need to rewind to
/// the start of a frame, then decompress until reaching the target offset.
///
/// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
pub struct ZstdReader<'dict, R> {
    reader: R,
    decoder: ZstdFramedDecoder<'dict>,
    buffer: crate::buffer::FixedBuffer<Vec<u8>>,
    current_pos: u64,
}

impl<R> ZstdReader<'_, std::io::BufReader<R>> {
    /// Create a new zstd reader that decompresses the zstd stream from
    /// the underlying reader. The provided reader will be wrapped with
    /// an appropriately-sized buffer.
    pub fn builder(reader: R) -> ZstdReaderBuilder<std::io::BufReader<R>>
    where
        R: std::io::Read,
    {
        ZstdReaderBuilder::new(reader)
    }

    /// Create a new zstd reader that decompresses the zstd stream from
    /// the underlying reader. The underlying reader must implement
    /// [`std::io::BufRead`], and its buffer will be used directly. When in
    /// doubt, use [`ZstdReader::builder`], which uses an appropriate
    /// buffer size for decompressing a zstd stream.
    pub fn builder_buffered(reader: R) -> ZstdReaderBuilder<R> {
        ZstdReaderBuilder::with_buffered(reader)
    }
}

impl<R> ZstdReader<'_, R> {
    /// Jump forward, decoding and consuming `length` decompressed bytes
    /// from the zstd stream.
    fn jump_forward(&mut self, mut length: u64) -> std::io::Result<()>
    where
        R: std::io::BufRead,
    {
        while length > 0 {
            // Decode some data from the underyling reader
            let decoded = self.fill_buf()?;

            // Return an error if we don't have any more data to decode
            if decoded.is_empty() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "reached eof while trying to decode to offset",
                ));
            }

            // Consume the data (up to the remaining length we should jump)
            let decoded_len = u64::try_from(decoded.len()).map_err(std::io::Error::other)?;
            let consumed_len = decoded_len.min(length);
            let consumed_len_usize =
                usize::try_from(consumed_len).map_err(std::io::Error::other)?;

            self.consume(consumed_len_usize);

            // Keep iterating until we've consumed enough to reach our target
            length -= consumed_len;
        }

        Ok(())
    }

    /// Decode and consume the entire zstd stream until reaching the end.
    /// Stops when reaching EOF (i.e. the underlying reader had no more data)
    fn jump_to_end(&mut self) -> std::io::Result<()>
    where
        R: std::io::BufRead,
    {
        loop {
            // Decode some data from the underlying reader
            let decoded = self.fill_buf()?;

            // If we didn't get any more data, we're done
            if decoded.is_empty() {
                break;
            }

            // Consume all the decoded data
            let decoded_len = decoded.len();
            self.consume(decoded_len);
        }

        Ok(())
    }
}

impl<R> std::io::Read for ZstdReader<'_, R>
where
    R: std::io::BufRead,
{
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // Decode some data from the underlying reader
        let filled = self.fill_buf()?;

        // Get some of the decoded data, capped to `buf`'s length
        let consumable = filled.len().min(buf.len());

        // Copy the decoded data to `buf`
        buf[..consumable].copy_from_slice(&filled[..consumable]);

        // Consume the copied data
        self.consume(consumable);

        Ok(consumable)
    }
}

impl<R> std::io::BufRead for ZstdReader<'_, R>
where
    R: std::io::BufRead,
{
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        loop {
            // Check if our buffer contais any data we can return
            if !self.buffer.uncommitted().is_empty() {
                // If it does, we're done
                break;
            }

            // Get some data from the underlying reader
            let decodable = self.reader.fill_buf()?;
            if decodable.is_empty() {
                // If the underlying reader doesn't have any more data,
                // then we're done
                break;
            }

            // Decode the data, and write it to `self.buffer`
            let consumed = self.decoder.decode(decodable, &mut self.buffer)?;

            // Tell the underlying reader that we read the subset of
            // data we decoded
            self.reader.consume(consumed);
        }

        // Return all the data we have in `self.buffer`
        Ok(self.buffer.uncommitted())
    }

    fn consume(&mut self, amt: usize) {
        // Tell the buffer that we've committed the data that was consumed
        self.buffer.commit(amt);

        // Advance the reader's position
        let amt_u64 = u64::try_from(amt).unwrap();
        self.current_pos += amt_u64;
    }
}

impl<R> std::io::Seek for ZstdReader<'_, R>
where
    R: std::io::BufRead + std::io::Seek,
{
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        // Get the target position relative to the stream start
        let target_pos = match pos {
            std::io::SeekFrom::Start(offset) => {
                // Position is already relative to the start
                offset
            }
            std::io::SeekFrom::End(offset) => {
                // Seek to the last position we know about in the stream.
                let seek = self.decoder.prepare_seek_to_last_known_pos();

                if let Some(frame) = seek.seek_to_frame_start {
                    // We need to seek to the start of a frame

                    // Seek the underlying reader
                    self.reader
                        .seek(std::io::SeekFrom::Start(frame.compressed_pos))?;

                    // Update the decoder based on what frame we're now at
                    self.decoder.seeked_to_frame(frame)?;

                    // Update our internal position to align with the start of the frame
                    self.current_pos = frame.decompressed_pos;

                    // Clear the buffer
                    self.buffer.clear();
                }

                // Jump to the end of the stream
                self.jump_to_end()?;

                // The current position is now at the end, so
                // add the end offset
                self.current_pos
                    .checked_add_signed(offset)
                    .ok_or_else(|| std::io::Error::other("invalid seek offset"))?
            }
            std::io::SeekFrom::Current(offset) => {
                // Add the offset to the current position
                self.current_pos
                    .checked_add_signed(offset)
                    .ok_or_else(|| std::io::Error::other("invalid seek offset"))?
            }
        };

        // Consume any leftover data in `self.buffer`. This ensures that
        // the current position is in line with the underlying decoder
        self.consume(self.buffer.uncommitted().len());

        // Determine what we need to do to reach the target position
        let seek = self.decoder.prepare_seek_to_decompressed_pos(target_pos);

        if let Some(frame) = seek.seek_to_frame_start {
            // We need to seek to the start of a frame

            // Seek the underlying reader
            self.reader
                .seek(std::io::SeekFrom::Start(frame.compressed_pos))?;

            // Update the decoder based on what frame we're now at
            self.decoder.seeked_to_frame(frame)?;

            // Update our internal position to align with the start of the frame
            self.current_pos = frame.decompressed_pos;
        }

        // Seek the remaining distance (if any) to reach the target position
        self.jump_forward(seek.decompress_len)?;

        assert_eq!(self.current_pos, target_pos);
        Ok(self.current_pos)
    }
}

/// A builder that builds a [`ZstdReader`] from the provided reader.
pub struct ZstdReaderBuilder<R> {
    reader: R,
    table: ZstdSeekTable,
}

impl<R> ZstdReaderBuilder<std::io::BufReader<R>> {
    fn new(reader: R) -> Self
    where
        R: std::io::Read,
    {
        let reader = std::io::BufReader::with_capacity(zstd::zstd_safe::DCtx::in_size(), reader);
        ZstdReaderBuilder::with_buffered(reader)
    }
}

impl<R> ZstdReaderBuilder<R> {
    fn with_buffered(reader: R) -> Self {
        ZstdReaderBuilder {
            reader,
            table: ZstdSeekTable::empty(),
        }
    }

    /// Use the given seek table when seeking the resulting reader. This can
    /// greatly speed up seek operations when using a zstd stream that
    /// uses the [zstd seekable format].
    ///
    /// See [`crate::table::read_seek_table`] for reading a seek table.
    ///
    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
    pub fn with_seek_table(mut self, table: ZstdSeekTable) -> Self {
        self.table = table;
        self
    }

    /// Build the reader.
    pub fn build(self) -> std::io::Result<ZstdReader<'static, R>> {
        let zstd_decoder = zstd::stream::raw::Decoder::new()?;
        let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::DCtx::out_size()]);
        let decoder = ZstdFramedDecoder::new(zstd_decoder, self.table);

        Ok(ZstdReader {
            reader: self.reader,
            decoder,
            buffer,
            current_pos: 0,
        })
    }
}