zstd_framed/
async_reader.rs

1use crate::{
2    decoder::ZstdFramedDecoder,
3    table::{ZstdFrame, ZstdSeekTable},
4};
5
pin_project_lite::pin_project! {
    /// A reader that decompresses a zstd stream from an underlying
    /// async reader. Works as either a `tokio` or `futures` reader if
    /// their respective features are enabled.
    ///
    /// The underlying reader `R` should implement the following traits:
    ///
    /// - `tokio`
    ///   - [`tokio::io::AsyncBufRead`] (required for [`tokio::io::AsyncRead`] and [`tokio::io::AsyncBufRead`] impls)
    ///   - (Optional) [`tokio::io::AsyncSeek`] (used when calling [`.seekable()`](AsyncZstdReader::seekable))
    /// - `futures`
    ///   - [`futures::AsyncBufRead`] (required for [`futures::AsyncRead`] and [`futures::AsyncBufRead`] impls)
    ///   - (Optional) [`futures::AsyncSeek`] (used when calling [`.seekable()`](AsyncZstdReader::seekable))
    ///
    /// For sync I/O support, see [`crate::ZstdReader`].
    ///
    /// ## Construction
    ///
    /// Create a builder using [`AsyncZstdReader::builder_tokio`] (recommended
    /// for `tokio`) or [`AsyncZstdReader::builder_futures`] (recommended for
    /// `futures`); or use [`AsyncZstdReader::builder_buffered`] to use a custom
    /// buffer for either. See [`ZstdReaderBuilder`] for build options. Call
    /// [`AsyncZstdReaderBuilder::build`] to build the
    /// [`AsyncZstdReader`] instance.
    ///
    /// ```
    /// # #[cfg(feature = "tokio")]
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let compressed_file: &[u8] = &[];
    /// // Tokio example
    /// let reader = zstd_framed::AsyncZstdReader::builder_tokio(compressed_file)
    ///     // .with_seek_table(table) // Provide a seek table if available
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// # #[cfg(not(feature = "tokio"))]
    /// # fn main() { }
    /// ```
    ///
    /// ```
    /// # #[cfg(feature = "futures")]
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let compressed_file: &[u8] = &[];
    /// // futures example
    /// let reader = zstd_framed::AsyncZstdReader::builder_futures(compressed_file)
    ///     // .with_seek_table(table) // Provide a seek table if available
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// # #[cfg(not(feature = "futures"))]
    /// # fn main() { }
    /// ```
    ///
    /// ## Buffering
    ///
    /// The decompressed zstd output is always buffered internally. Since the
    /// reader must also implement [`tokio::io::AsyncBufRead`] /
    /// [`futures::AsyncBufRead`], the compressed input must also be buffered.
    ///
    /// [`AsyncZstdReader::builder_tokio`] and
    /// [`AsyncZstdReader::builder_futures`] will wrap any reader implementing
    /// [`tokio::io::AsyncRead`] or [`futures::AsyncRead`] (respectively)
    /// with a recommended buffer size for the input stream. For more control
    /// over how the input gets buffered, you can instead use
    /// [`AsyncZstdReader::builder_buffered`].
    ///
    /// ## Seeking
    ///
    /// If the underlying reader is seekable (i.e. it implements either
    /// [`tokio::io::AsyncSeek`] or [`futures::AsyncSeek`]), you can call
    /// [`.seekable()`](`AsyncZstdReader::seekable`) to convert it to a seekable reader. See
    /// [`AsyncZstdSeekableReader`] for notes and caveats about seeking.
    pub struct AsyncZstdReader<'dict, R> {
        // Underlying buffered reader that supplies the compressed bytes.
        // Pinned structurally so its `poll_*` methods can be called.
        #[pin]
        reader: R,
        // Decoder that turns compressed input into decompressed output
        decoder: ZstdFramedDecoder<'dict>,
        // Holds decoded output; its "uncommitted" part is the data that
        // has been decoded but not yet consumed by the caller
        buffer: crate::buffer::FixedBuffer<Vec<u8>>,
        // Position within the decompressed stream, advanced by every
        // `consume` call
        current_pos: u64,
    }
}
86
87impl<'dict, R> AsyncZstdReader<'dict, R> {
88    /// Create a new zstd reader that decompresses the zstd stream from
89    /// the underlying Tokio reader. The provided reader will be wrapped
90    /// with an appropriately-sized buffer.
91    #[cfg(feature = "tokio")]
92    pub fn builder_tokio(reader: R) -> ZstdReaderBuilder<tokio::io::BufReader<R>>
93    where
94        R: tokio::io::AsyncRead,
95    {
96        ZstdReaderBuilder::new_tokio(reader)
97    }
98
99    /// Create a new zstd reader that decompresses the zstd stream from
100    /// the underlying `futures` reader. The provided reader will be
101    /// wrapped with an appropriately-sized buffer.
102    #[cfg(feature = "futures")]
103    pub fn builder_futures(reader: R) -> ZstdReaderBuilder<futures::io::BufReader<R>>
104    where
105        R: futures::AsyncRead,
106    {
107        ZstdReaderBuilder::new_futures(reader)
108    }
109
110    /// Create a new zstd reader that decompresses the zstd stream from
111    /// the underlying reader. The underlying reader must implement
112    /// either [`tokio::io::AsyncBufRead`] or [`futures::AsyncBufRead`],
113    /// and its buffer will be used directly. When in doubt, use
114    /// one of the other builder methods to use an appropriate buffer size
115    /// for decompressing a zstd stream.
116    pub fn builder_buffered(reader: R) -> ZstdReaderBuilder<R> {
117        ZstdReaderBuilder::with_buffered(reader)
118    }
119
120    /// Wrap the reader with [`AsyncZstdSeekableReader`], which adds support
121    /// for seeking if the underlying reader supports seeking.
122    pub fn seekable(self) -> AsyncZstdSeekableReader<'dict, R> {
123        AsyncZstdSeekableReader {
124            reader: self,
125            pending_seek: None,
126        }
127    }
128
129    /// Decode and consume the entire zstd stream until reaching the end.
130    /// Stops when reaching EOF (i.e. the underlying reader had no more data)
131    #[cfg(feature = "tokio")]
132    fn poll_jump_to_end_tokio(
133        mut self: std::pin::Pin<&mut Self>,
134        cx: &mut std::task::Context<'_>,
135    ) -> std::task::Poll<std::io::Result<()>>
136    where
137        R: tokio::io::AsyncBufRead,
138    {
139        use tokio::io::AsyncBufRead as _;
140
141        loop {
142            // Decode some data from the underlying reader
143            let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;
144
145            // If we didn't get any more data, we're done
146            if decoded.is_empty() {
147                break;
148            }
149
150            // Consume all the decoded data
151            let decoded_len = decoded.len();
152            self.as_mut().consume(decoded_len);
153        }
154
155        std::task::Poll::Ready(Ok(()))
156    }
157
158    /// Decode and consume the entire zstd stream until reaching the end.
159    /// Stops when reaching EOF (i.e. the underlying reader had no more data)
160    #[cfg(feature = "futures")]
161    fn poll_jump_to_end_futures(
162        mut self: std::pin::Pin<&mut Self>,
163        cx: &mut std::task::Context<'_>,
164    ) -> std::task::Poll<std::io::Result<()>>
165    where
166        R: futures::AsyncBufRead,
167    {
168        use futures::AsyncBufRead as _;
169
170        loop {
171            // Decode some data from the underlying reader
172            let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;
173
174            // If we didn't get any more data, we're done
175            if decoded.is_empty() {
176                break;
177            }
178
179            // Consume all the decoded data
180            let decoded_len = decoded.len();
181            self.as_mut().consume(decoded_len);
182        }
183
184        std::task::Poll::Ready(Ok(()))
185    }
186}
187
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncBufRead for AsyncZstdReader<'_, R>
where
    R: tokio::io::AsyncBufRead,
{
    /// Return the decoded-but-unconsumed bytes, decoding more input from
    /// the underlying reader first if none are available.
    ///
    /// An empty slice is returned once the underlying reader has no more
    /// data to offer.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        use crate::buffer::Buffer as _;

        // Keep decoding for as long as there's no decoded output to return
        while self.buffer.uncommitted().is_empty() {
            let mut this = self.as_mut().project();

            // Ask the underlying reader for more compressed input
            let compressed = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
            if compressed.is_empty() {
                // The underlying reader is exhausted, so stop here
                break;
            }

            // Decode into our buffer, then mark only the portion the
            // decoder actually used as read
            let used = this.decoder.decode(compressed, this.buffer)?;
            this.reader.consume(used);
        }

        // Hand out whatever decoded data is buffered (possibly nothing)
        std::task::Poll::Ready(Ok(self.project().buffer.uncommitted()))
    }

    /// Mark `amt` bytes of decoded output as consumed and advance the
    /// reader's position by the same amount.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        use crate::buffer::Buffer as _;

        let this = self.project();

        // Commit the consumed bytes in the output buffer
        this.buffer.commit(amt);

        // Move our position forward to match
        let advanced: u64 = amt.try_into().unwrap();
        *this.current_pos += advanced;
    }
}
241
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncRead for AsyncZstdReader<'_, R>
where
    R: tokio::io::AsyncBufRead,
{
    /// Read decompressed data into `buf`.
    ///
    /// Decodes some data (if none is already buffered), copies as much of
    /// it as fits into the unfilled part of `buf`, and marks the copied
    /// bytes as consumed. If `buf` has no remaining capacity, returns
    /// immediately without polling the underlying reader.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use tokio::io::AsyncBufRead as _;

        // Nothing can be copied into a full buffer, so don't wait on
        // (or surface errors from) the underlying reader. This matches
        // the `futures::AsyncRead` impl's handling of an empty buffer
        if buf.remaining() == 0 {
            return std::task::Poll::Ready(Ok(()));
        }

        // Decode some data from the underlying reader
        let filled = ready!(self.as_mut().poll_fill_buf(cx))?;

        // Get some of the decoded data, capped to `buf`'s remaining capacity
        let consumable = filled.len().min(buf.remaining());

        // Copy the decoded data to `buf`
        buf.put_slice(&filled[..consumable]);

        // Consume the copied data
        self.consume(consumable);

        std::task::Poll::Ready(Ok(()))
    }
}
269
#[cfg(feature = "futures")]
impl<R> futures::AsyncBufRead for AsyncZstdReader<'_, R>
where
    R: futures::io::AsyncBufRead,
{
    /// Return the decoded-but-unconsumed bytes, decoding more input from
    /// the underlying reader first if none are available.
    ///
    /// An empty slice is returned once the underlying reader has no more
    /// data to offer.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        use crate::buffer::Buffer as _;

        // Keep decoding for as long as there's no decoded output to return
        while self.buffer.uncommitted().is_empty() {
            let mut this = self.as_mut().project();

            // Ask the underlying reader for more compressed input
            let compressed = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
            if compressed.is_empty() {
                // The underlying reader is exhausted, so stop here
                break;
            }

            // Decode into our buffer, then mark only the portion the
            // decoder actually used as read
            let used = this.decoder.decode(compressed, this.buffer)?;
            this.reader.consume(used);
        }

        // Hand out whatever decoded data is buffered (possibly nothing)
        std::task::Poll::Ready(Ok(self.project().buffer.uncommitted()))
    }

    /// Mark `amt` bytes of decoded output as consumed and advance the
    /// reader's position by the same amount.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        use crate::buffer::Buffer as _;

        let this = self.project();

        // Commit the consumed bytes in the output buffer
        this.buffer.commit(amt);

        // Move our position forward to match
        let advanced: u64 = amt.try_into().unwrap();
        *this.current_pos += advanced;
    }
}
323
#[cfg(feature = "futures")]
impl<R> futures::AsyncRead for AsyncZstdReader<'_, R>
where
    R: futures::AsyncBufRead,
{
    /// Read decompressed data into `buf`, returning how many bytes were
    /// copied.
    ///
    /// An empty `buf` yields `Ok(0)` right away, without polling the
    /// underlying reader.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use futures::AsyncBufRead as _;

        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }

        // Make sure there's decoded data available (or that we hit EOF)
        let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;

        // We can copy no more than what fits in `buf`, and no more than
        // what was decoded
        let copy_len = buf.len().min(decoded.len());

        // Fill the front of `buf` with the decoded bytes
        let (dest, _) = buf.split_at_mut(copy_len);
        dest.copy_from_slice(&decoded[..copy_len]);

        // Mark the copied bytes as read
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(copy_len))
    }
}
355
pin_project_lite::pin_project! {
    /// A wrapper around [`AsyncZstdReader`] with extra support for seeking.
    /// Created via the method [`AsyncZstdReader::seekable`].
    ///
    /// The underlying reader `R` should implement the following traits:
    ///
    /// - `tokio`
    ///   - [`tokio::io::AsyncBufRead`] + [`tokio::io::AsyncSeek`] (required for [`tokio::io::AsyncRead`], [`tokio::io::AsyncBufRead`], and [`tokio::io::AsyncSeek`] impls)
    /// - `futures`
    ///   - [`futures::AsyncBufRead`] + [`futures::AsyncSeek`] (required for [`futures::AsyncRead`], [`futures::AsyncBufRead`], and [`futures::AsyncSeek`] impls)
    ///
    /// **By default, seeking within the stream will linearly decompress
    /// until reaching the target!**
    ///
    /// Seeking can do a lot better when the underlying stream is broken up
    /// into multiple frames, such as a stream that uses the [zstd seekable format].
    /// You can create such a stream using [`ZstdWriterBuilder::with_seek_table`](crate::async_writer::ZstdWriterBuilder::with_seek_table).
    ///
    /// There are two situations where seeking can take advantage of a seek
    /// table:
    ///
    /// 1. When a seek table is provided up-front using [`ZstdReaderBuilder::with_seek_table`].
    ///    See [`crate::table::read_seek_table`] for reading a seek table
    ///    from a reader (there are also async-friendly functions available).
    /// 2. When rewinding to a previously-decompressed frame. Frame offsets are
    ///    automatically recorded during decompression.
    ///
    /// Even if a seek table is used, seeking will still need to rewind to
    /// the start of a frame, then decompress until reaching the target offset.
    ///
    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
    pub struct AsyncZstdSeekableReader<'dict, R> {
        // The wrapped decompressing reader
        #[pin]
        reader: AsyncZstdReader<'dict, R>,
        // State of a seek that has been started but has not finished yet;
        // `None` when no seek is in flight
        pending_seek: Option<PendingSeek>,
    }
}
394
395impl<R> AsyncZstdSeekableReader<'_, R> {
    /// If a seek operation was started with [`tokio::io::AsyncSeek::start_seek`]
    /// but wasn't polled to completion, "undo" the seek by seeking
    /// back to where we were in the zstd stream.
    ///
    /// This is a state machine driven by `self.pending_seek`. It loops
    /// until the pending seek is cleared, returning `Poll::Pending`
    /// whenever the underlying reader isn't ready:
    ///
    /// - No pending seek: nothing to do, returns `Ok(())`.
    /// - `Starting`: nothing was done yet, so the pending seek is just
    ///   cleared.
    /// - Any in-progress forward state: discards buffered output, asks the
    ///   decoder how to get back to `initial_pos`, then moves to
    ///   `RestoringSeekToFrame` (after starting a seek on the underlying
    ///   reader) or directly to `RestoringJumpForward`.
    /// - `RestoringSeekToFrame`: waits for the underlying seek to finish,
    ///   syncs the position and decoder to the frame start, then moves to
    ///   `RestoringJumpForward`.
    /// - `RestoringJumpForward`: decodes and discards data until
    ///   `decompress_len` reaches zero, then clears the pending seek.
    ///
    /// # Errors
    ///
    /// Any failure clears the pending seek before returning. Errors from
    /// the underlying reader or the decoder are wrapped in a new
    /// [`std::io::Error`] with the message prefix
    /// `failed to cancel in-progress zstd seek`. Hitting EOF before
    /// reaching the initial position yields
    /// [`std::io::ErrorKind::UnexpectedEof`].
    ///
    /// # Panics
    ///
    /// Panics if, after restoring, the reader's position doesn't match the
    /// position recorded when the seek started, or if a length fails to
    /// convert between `usize` and `u64`.
    #[cfg(feature = "tokio")]
    fn poll_cancel_seek_tokio(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>>
    where
        R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
    {
        use crate::buffer::Buffer as _;
        use tokio::io::AsyncBufRead as _;

        let mut this = self.project();

        // Iterate until `self.pending_seek` is unset. Each iteration
        // should make some progress based on the current pending seek
        // state (or should return `Poll::Pending`)
        loop {
            let Some(pending_seek) = *this.pending_seek else {
                // No pending seek, which means we're done!
                return std::task::Poll::Ready(Ok(()));
            };

            match pending_seek.state {
                PendingSeekState::Starting => {
                    // Seek just started. There's nothing to undo, so
                    // just clear the pending seek
                    *this.pending_seek = None;
                }
                PendingSeekState::SeekingToLastFrame { .. }
                | PendingSeekState::JumpingToEnd { .. }
                | PendingSeekState::SeekingToTarget { .. }
                | PendingSeekState::SeekingToFrame { .. }
                | PendingSeekState::JumpingForward { .. } => {
                    // Seek is in progress

                    // Consume any leftover data in `self.buffer`. This ensures
                    // that the current position is in line with the
                    // underlying decoder
                    let consumable = this.reader.buffer.uncommitted().len();
                    this.reader.as_mut().consume(consumable);

                    // Determine what we need to do to reach the target position
                    // (here, the position we were at when the seek started)
                    let seek = this
                        .reader
                        .decoder
                        .prepare_seek_to_decompressed_pos(pending_seek.initial_pos);

                    if let Some(frame) = seek.seek_to_frame_start {
                        // We need to seek to the start of a frame

                        // Transition the state to indicate we're seeking
                        // to the start of a frame. This is recorded before
                        // submitting the seek; if submitting fails, the
                        // pending seek is cleared below
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringSeekToFrame {
                                frame,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });

                        // Submit a seek job to the underlying reader
                        let reader = this.reader.as_mut().project().reader;
                        let result =
                            reader.start_seek(std::io::SeekFrom::Start(frame.compressed_pos));

                        match result {
                            Ok(_) => {}
                            Err(error) => {
                                // Trying to seek the underlying reader
                                // failed, so clear the pending seek and bail
                                *this.pending_seek = None;
                                return std::task::Poll::Ready(Err(std::io::Error::other(
                                    format!("failed to cancel in-progress zstd seek: {error}"),
                                )));
                            }
                        }
                    } else {
                        // We just need to keep decompressing to reach the
                        // target position

                        // Transition to a state to indicate how many
                        // bytes we need to consume
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringJumpForward {
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    }
                }
                PendingSeekState::RestoringSeekToFrame {
                    frame,
                    decompress_len,
                } => {
                    // We're in the process of restoring to the previous
                    // seek position, and need to seek to the start of
                    // a frame

                    let reader = this.reader.as_mut().project();

                    // Poll until the seek completes
                    let result = ready!(reader.reader.poll_complete(cx));

                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed, so
                            // clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our internal position to align with the start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Error from the decoder, so clear the pending
                            // seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our state to decompress as much data as we
                    // need to reach the initial position again
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward { decompress_len },
                        ..pending_seek
                    });
                }
                PendingSeekState::RestoringJumpForward { decompress_len: 0 } => {
                    // We finished getting back to the initial position!
                    // Clear the pending seek

                    // Sanity check: we should be exactly where we were
                    // when the seek started
                    assert_eq!(pending_seek.initial_pos, this.reader.current_pos);
                    *this.pending_seek = None;
                }
                PendingSeekState::RestoringJumpForward { decompress_len } => {
                    // We have to decompress some data to reach the initial
                    // position

                    // Try to decompress some data from the underlying
                    // reader
                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
                    let filled = match result {
                        Ok(filled) => filled,
                        Err(error) => {
                            // Failed to decompress from the underlying
                            // reader, so clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    };

                    if filled.is_empty() {
                        // The underlying reader didn't give us any data, which
                        // means we hit EOF while trying to get back to the
                        // initial position. Clear the pending seek and bail

                        *this.pending_seek = None;
                        return std::task::Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "reached eof while trying to cancel in-progress zstd seek",
                        )));
                    }

                    // Consume as much data as we can to try and reach
                    // the total data to decompress. The amount is capped at
                    // `decompress_len` so we never move past the initial
                    // position
                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
                    let jump_len = filled_len_u64.min(decompress_len);
                    let jump_len_usize: usize = jump_len.try_into().unwrap();
                    this.reader.as_mut().consume(jump_len_usize);

                    // Update the state based on how much more data
                    // we have left to decompress
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward {
                            decompress_len: decompress_len - jump_len,
                        },
                        ..pending_seek
                    });
                }
            }
        }
    }
595
596    /// If a seek operation was started with [`futures::AsyncSeek::poll_seek`]
597    /// but wasn't polled to completion, "undo" the seek by seeking
598    /// back to where we were in the zstd stream.
599    #[cfg(feature = "futures")]
600    fn poll_cancel_seek_futures(
601        self: std::pin::Pin<&mut Self>,
602        cx: &mut std::task::Context<'_>,
603    ) -> std::task::Poll<std::io::Result<()>>
604    where
605        R: futures::AsyncBufRead + futures::AsyncSeek,
606    {
607        use crate::buffer::Buffer as _;
608        use futures::AsyncBufRead as _;
609
610        let mut this = self.project();
611
612        // Iterate until `self.pending_seek` is unset. Each iteration
613        // should make some progress based on the current pending seek
614        // state (or should return `Poll::Pending`)
615        loop {
616            let Some(pending_seek) = *this.pending_seek else {
617                // No pending seek, which means we're done!
618                return std::task::Poll::Ready(Ok(()));
619            };
620
621            match pending_seek.state {
622                PendingSeekState::Starting => {
623                    // Seek just started. There's nothing to undo, so
624                    // just clear the pending seek
625                    *this.pending_seek = None;
626                }
627                PendingSeekState::SeekingToLastFrame { .. }
628                | PendingSeekState::JumpingToEnd { .. }
629                | PendingSeekState::SeekingToTarget { .. }
630                | PendingSeekState::SeekingToFrame { .. }
631                | PendingSeekState::JumpingForward { .. } => {
632                    // Seek is in progress
633
634                    // Consume any leftover data in `self.buffer`. This ensures
635                    // that the current position is in line with the
636                    // underlying decoder
637                    let consumable = this.reader.buffer.uncommitted().len();
638                    this.reader.as_mut().consume(consumable);
639
640                    // Determine what we need to do to reach the target position
641                    let seek = this
642                        .reader
643                        .decoder
644                        .prepare_seek_to_decompressed_pos(pending_seek.initial_pos);
645
646                    if let Some(frame) = seek.seek_to_frame_start {
647                        // We need to seek to the start of a frame
648
649                        // Transition the state to indicate we're seeking
650                        // to the start of a frame
651                        *this.pending_seek = Some(PendingSeek {
652                            state: PendingSeekState::RestoringSeekToFrame {
653                                frame,
654                                decompress_len: seek.decompress_len,
655                            },
656                            ..pending_seek
657                        });
658                    } else {
659                        // We just need to keep decompressing to reach the
660                        // target position
661
662                        // Transition to a state to indicate how many
663                        // bytes we need to consume
664                        *this.pending_seek = Some(PendingSeek {
665                            state: PendingSeekState::RestoringJumpForward {
666                                decompress_len: seek.decompress_len,
667                            },
668                            ..pending_seek
669                        });
670                    }
671                }
672                PendingSeekState::RestoringSeekToFrame {
673                    frame,
674                    decompress_len,
675                } => {
676                    // We're in the process of restoring to the previous
677                    // seek position, and need to seek to the start of
678                    // a frame
679
680                    let reader = this.reader.as_mut().project();
681
682                    // Poll until we finish seeking the underlying reader
683                    let result = ready!(reader
684                        .reader
685                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));
686
687                    match result {
688                        Ok(_) => {}
689                        Err(error) => {
690                            // Seeking the underlying reader failed, so
691                            // clear the pending seek and bail
692                            *this.pending_seek = None;
693                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
694                                "failed to cancel in-progress zstd seek: {error}"
695                            ))));
696                        }
697                    }
698
699                    // Update our internal position to align with the start of the frame
700                    *reader.current_pos = frame.decompressed_pos;
701
702                    // Update the decoder based on what frame we're now at
703                    let result = reader.decoder.seeked_to_frame(frame);
704                    match result {
705                        Ok(_) => {}
706                        Err(error) => {
707                            // Error from the decoder, so clear the pending
708                            // seek and bail
709                            *this.pending_seek = None;
710                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
711                                "failed to cancel in-progress zstd seek: {error}"
712                            ))));
713                        }
714                    }
715
716                    // Update our state to decompress as much data as we
717                    // need to reach the initial position again
718                    *this.pending_seek = Some(PendingSeek {
719                        state: PendingSeekState::RestoringJumpForward { decompress_len },
720                        ..pending_seek
721                    });
722                }
723                PendingSeekState::RestoringJumpForward { decompress_len: 0 } => {
724                    // We finished getting back to the initial position!
725                    // Clear the pending seek
726
727                    assert_eq!(pending_seek.initial_pos, this.reader.current_pos);
728                    *this.pending_seek = None;
729                }
730                PendingSeekState::RestoringJumpForward { decompress_len } => {
731                    // We have to decompress some data to reach the initial
732                    // position
733
734                    // Try to decompress some data from the underlying
735                    // reader
736                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
737                    let filled = match result {
738                        Ok(filled) => filled,
739                        Err(error) => {
740                            // Failed to decompress from the underlying
741                            // reader, so clear the pending seek and bail
742                            *this.pending_seek = None;
743                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
744                                "failed to cancel in-progress zstd seek: {error}"
745                            ))));
746                        }
747                    };
748
749                    if filled.is_empty() {
750                        // The underlying reader didn't give us any data, which
751                        // means we hit EOF while trying to get back to the
752                        // initial position. Clear the pending seek and bail
753
754                        *this.pending_seek = None;
755                        return std::task::Poll::Ready(Err(std::io::Error::new(
756                            std::io::ErrorKind::UnexpectedEof,
757                            "reached eof while trying to cancel in-progress zstd seek",
758                        )));
759                    }
760
761                    // Consume as much data as we can to try and reach
762                    // the total data to decompress
763                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
764                    let jump_len = filled_len_u64.min(decompress_len);
765                    let jump_len_usize: usize = jump_len.try_into().unwrap();
766                    this.reader.as_mut().consume(jump_len_usize);
767
768                    // Update the state based on how much more data
769                    // we have left to decompress
770                    *this.pending_seek = Some(PendingSeek {
771                        state: PendingSeekState::RestoringJumpForward {
772                            decompress_len: decompress_len - jump_len,
773                        },
774                        ..pending_seek
775                    });
776                }
777            }
778        }
779    }
780}
781
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncBufRead for AsyncZstdSeekableReader<'_, R>
where
    R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
{
    /// Exposes the next chunk of decompressed data.
    ///
    /// Any seek that was started but never driven to completion is
    /// cancelled first (via `poll_cancel_seek_tokio`); only once that
    /// has finished is the call forwarded to the wrapped reader.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        // A read abandons whatever seek was left in flight, so wait
        // for the cancellation to finish and surface its error, if any
        let cancelled = ready!(self.as_mut().poll_cancel_seek_tokio(cx));
        if let Err(error) = cancelled {
            return std::task::Poll::Ready(Err(error));
        }

        // Nothing is pending anymore: hand off to the wrapped reader
        let projected = self.project();
        projected.reader.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the buffered data as consumed.
    ///
    /// # Panics
    ///
    /// Panics if a seek is still pending.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let projected = self.project();

        // Consuming in the middle of a seek is not allowed, so refuse
        // loudly before touching the buffer
        assert!(
            projected.pending_seek.is_none(),
            "tried to consume from buffer while seeking"
        );

        // Hand off to the wrapped reader
        projected.reader.consume(amt);
    }
}
813
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncRead for AsyncZstdSeekableReader<'_, R>
where
    R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
{
    /// Reads decompressed bytes into `buf`.
    ///
    /// Built on top of this type's own `poll_fill_buf` / `consume`:
    /// as many of the currently decoded bytes as fit into the
    /// remaining space of `buf` are copied over and then consumed.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use tokio::io::AsyncBufRead as _;

        // Ask the buffered side for whatever decoded data is available
        let decoded = match self.as_mut().poll_fill_buf(cx) {
            std::task::Poll::Ready(Ok(decoded)) => decoded,
            std::task::Poll::Ready(Err(error)) => return std::task::Poll::Ready(Err(error)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        };

        // Never copy more than the destination has room for
        let copy_len = buf.remaining().min(decoded.len());
        buf.put_slice(&decoded[..copy_len]);

        // The copied bytes are now owned by the caller, so drop them
        // from our buffer
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(()))
    }
}
841
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncSeek for AsyncZstdSeekableReader<'_, R>
where
    R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
{
    /// Registers a seek to `position`, to be driven by
    /// [`poll_complete`](tokio::io::AsyncSeek::poll_complete).
    ///
    /// This only records the request (together with the current
    /// decompressed position, so the seek can later be undone if it gets
    /// cancelled); the actual work happens in `poll_complete`.
    ///
    /// # Errors
    ///
    /// Returns an error if another seek is already pending.
    fn start_seek(
        self: std::pin::Pin<&mut Self>,
        position: std::io::SeekFrom,
    ) -> std::io::Result<()> {
        let mut this = self.project();

        // Ensure there isn't another seek in progress first
        if this.pending_seek.is_some() {
            return Err(std::io::Error::other("seek already in progress"));
        }

        // Transition to the "starting" seek state
        *this.pending_seek = Some(PendingSeek {
            initial_pos: this.reader.as_mut().current_pos,
            seek: position,
            state: PendingSeekState::Starting,
        });
        Ok(())
    }

    /// Drives the pending seek (if any) to completion and returns the
    /// resulting position in the decompressed stream.
    ///
    /// If no seek is pending, the current position is returned
    /// immediately. Otherwise the pending seek is advanced through its
    /// states until it either finishes, fails, or has to wait
    /// (`Poll::Pending`).
    ///
    /// # Errors
    ///
    /// - Errors from the underlying reader or from the decoder are
    ///   returned as-is.
    /// - An offset computation that overflows yields an
    ///   "invalid seek offset" error.
    /// - Hitting the end of the stream before the target position yields
    ///   an `UnexpectedEof` error.
    /// - If the seek was cancelled while in flight, the cancellation is
    ///   finished first and then a "seek cancelled" error is returned.
    ///
    /// In the first three cases the pending seek is cleared before
    /// returning.
    ///
    /// # Panics
    ///
    /// Panics if, after decompressing up to the target, the reader's
    /// position does not match the target position.
    fn poll_complete(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<u64>> {
        use crate::buffer::Buffer as _;
        use tokio::io::AsyncBufRead as _;

        // Iterate until `self.pending_seek` is unset. Each iteration
        // should make some progress based on the current pending seek
        // state (or should return `Poll::Pending`)
        loop {
            let mut this = self.as_mut().project();

            let Some(pending_seek) = *this.pending_seek else {
                return std::task::Poll::Ready(Ok(this.reader.current_pos));
            };

            match pending_seek.state {
                PendingSeekState::Starting => {
                    // Seek is starting. The first step is to determine
                    // the seek target relative to the start of the stream

                    match pending_seek.seek {
                        std::io::SeekFrom::Start(offset) => {
                            // The offset is already relative to the start,
                            // so transition to the state to start seeking
                            *this.pending_seek = Some(PendingSeek {
                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
                                ..pending_seek
                            });
                        }
                        std::io::SeekFrom::End(end_offset) => {
                            // To determine the offset relative to the start,
                            // we first need to reach the end of the stream.

                            // Determine the best way to reach the last
                            // position we know about in the stream.
                            let seek = this.reader.decoder.prepare_seek_to_last_known_pos();

                            if let Some(frame) = seek.seek_to_frame_start {
                                // We have a frame to seek to

                                // Start a job to seek to the start
                                // of the last known frame
                                let reader = this.reader.as_mut().project().reader;
                                let result = reader
                                    .start_seek(std::io::SeekFrom::Start(frame.compressed_pos));

                                match result {
                                    Ok(_) => {}
                                    Err(error) => {
                                        // Trying to seek the underlying reader
                                        // failed, so clear the pending seek and
                                        // bail. The error is passed through
                                        // untouched (like every other failure
                                        // in this method) so the caller keeps
                                        // the original error kind; nothing is
                                        // being cancelled at this point.
                                        *this.pending_seek = None;
                                        return std::task::Poll::Ready(Err(error));
                                    }
                                }

                                // Transition to the state to seek before
                                // jumping to the end of the stream
                                *this.pending_seek = Some(PendingSeek {
                                    state: PendingSeekState::SeekingToLastFrame {
                                        frame,
                                        end_offset,
                                    },
                                    ..pending_seek
                                });
                            } else {
                                // No need to seek, so transition to the
                                // state to jump to the end of the stream
                                *this.pending_seek = Some(PendingSeek {
                                    state: PendingSeekState::JumpingToEnd { end_offset },
                                    ..pending_seek
                                });
                            }
                        }
                        std::io::SeekFrom::Current(offset) => {
                            // Compute the offset relative to the current position
                            let offset = this.reader.current_pos.checked_add_signed(offset);
                            let offset = match offset {
                                Some(offset) => offset,
                                None => {
                                    // Offset overflowed, so clear the pending
                                    // seek and return an error
                                    *this.pending_seek = None;
                                    return std::task::Poll::Ready(Err(std::io::Error::other(
                                        "invalid seek offset",
                                    )));
                                }
                            };

                            // Transition to the state to start seeking based
                            // on the computed offset
                            *this.pending_seek = Some(PendingSeek {
                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
                                ..pending_seek
                            });
                        }
                    }
                }
                PendingSeekState::SeekingToLastFrame { end_offset, frame } => {
                    // We're seeking to the last (known) frame in the stream
                    // before jumping to the end

                    let reader = this.reader.as_mut().project();

                    // Wait for the seek on the underlying reader to complete
                    let result = ready!(reader.reader.poll_complete(cx));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed,
                            // so clear the in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // The decoder failed, so clear the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    }

                    // Update our internal position to align with the
                    // start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Clear the buffer
                    reader.buffer.clear();

                    // Seek complete, so transition states to jump to
                    // the end of the stream
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingToEnd { end_offset },
                        ..pending_seek
                    });
                }
                PendingSeekState::JumpingToEnd { end_offset } => {
                    // Seek target is relative to the end of the stream, so
                    // we need to jump to the end of the stream before
                    // determining the target position

                    // Try to jump to the end of stream
                    let result = ready!(this.reader.as_mut().poll_jump_to_end_tokio(cx));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Jumping to the end failed, so cancel the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Now we're at the end of the stream, so we can now
                    // compute the target position
                    let target_pos = this.reader.current_pos.checked_add_signed(end_offset);
                    let target_pos = match target_pos {
                        Some(target_pos) => target_pos,
                        None => {
                            // Target position overflowed, so cancel the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(
                                "invalid seek offset",
                            )));
                        }
                    };

                    // Transition to the state to start seeking based
                    // on the computed offset
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::SeekingToTarget { target_pos },
                        ..pending_seek
                    });
                }
                PendingSeekState::SeekingToTarget { target_pos } => {
                    // We now know the relative position we're seeking to

                    // Consume any leftover data in `self.buffer`. This ensures
                    // that the current position is in line with the
                    // underlying decoder
                    let consumable = this.reader.buffer.uncommitted().len();
                    this.reader.as_mut().consume(consumable);

                    // Determine what we need to do to reach the target position
                    let seek = this
                        .reader
                        .decoder
                        .prepare_seek_to_decompressed_pos(target_pos);

                    if let Some(frame) = seek.seek_to_frame_start {
                        // We need to seek to the start of a frame

                        // Transition to the state so we poll until
                        // the seek completes
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::SeekingToFrame {
                                target_pos,
                                frame,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });

                        let reader = this.reader.as_mut().project().reader;

                        // Start a job to seek the underlying reader
                        let result =
                            reader.start_seek(std::io::SeekFrom::Start(frame.compressed_pos));

                        match result {
                            Ok(_) => {}
                            Err(error) => {
                                // Trying to seek the underlying reader
                                // failed, so clear the in-progress seek
                                // and bail
                                *this.pending_seek = None;
                                return std::task::Poll::Ready(Err(error));
                            }
                        }
                    } else {
                        // We need to keep decoding bytes to reach the
                        // target position

                        // Transition to the state so that we can keep
                        // decompressing until reaching the target position
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::JumpingForward {
                                target_pos,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    }
                }
                PendingSeekState::SeekingToFrame {
                    target_pos,
                    frame,
                    decompress_len,
                } => {
                    // We're seeking to the start of a frame

                    let reader = this.reader.as_mut().project();

                    // Poll until the underlying reader finishes seeking
                    let result = ready!(reader.reader.poll_complete(cx));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed,
                            // so clear the in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // The decoder failed, so clear the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    }

                    // Update our internal position to align with the
                    // start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Seek complete, so transition states to decompress
                    // until reaching the target position
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingForward {
                            target_pos,
                            decompress_len,
                        },
                        ..pending_seek
                    });
                }
                PendingSeekState::JumpingForward {
                    target_pos,
                    decompress_len: 0,
                } => {
                    // No more bytes to decompress, so at the target position!
                    // Clear the pending seek
                    assert_eq!(target_pos, this.reader.current_pos);
                    *this.pending_seek = None;
                }
                PendingSeekState::JumpingForward {
                    target_pos,
                    decompress_len,
                } => {
                    // We have some bytes to decompress before reaching
                    // the target position

                    // Try to decompress some data from the underlying
                    // reader
                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
                    let filled = match result {
                        Ok(filled) => filled,
                        Err(error) => {
                            // Failed to decompress from the underlying
                            // reader, so clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    if filled.is_empty() {
                        // The underlying reader didn't give us any data, which
                        // means we hit EOF while trying to seek. Clear the
                        // pending seek and bail
                        *this.pending_seek = None;
                        return std::task::Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "reached eof while trying to decode to offset",
                        )));
                    }

                    // Consume as much data as we can to try and reach
                    // the total data to decompress
                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
                    let jump_len = filled_len_u64.min(decompress_len);
                    let jump_len_usize: usize = jump_len.try_into().unwrap();
                    this.reader.as_mut().consume(jump_len_usize);

                    // Update the state based on how much more data
                    // we have left to decompress
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingForward {
                            target_pos,
                            decompress_len: decompress_len - jump_len,
                        },
                        ..pending_seek
                    });
                }
                PendingSeekState::RestoringSeekToFrame { .. }
                | PendingSeekState::RestoringJumpForward { .. } => {
                    // The seek was cancelled, so poll until the
                    // cancellation finished then return an error
                    ready!(self.as_mut().poll_cancel_seek_tokio(cx))?;
                    return std::task::Poll::Ready(Err(std::io::Error::other("seek cancelled")));
                }
            }
        }
    }
}
1229
#[cfg(feature = "futures")]
impl<R> futures::AsyncBufRead for AsyncZstdSeekableReader<'_, R>
where
    R: futures::AsyncBufRead + futures::AsyncSeek,
{
    /// Exposes the next chunk of decompressed data.
    ///
    /// Any seek that was started but never driven to completion is
    /// cancelled first (via `poll_cancel_seek_futures`); only once that
    /// has finished is the call forwarded to the wrapped reader.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        // A read abandons whatever seek was left in flight, so wait
        // for the cancellation to finish and surface its error, if any
        let cancelled = ready!(self.as_mut().poll_cancel_seek_futures(cx));
        if let Err(error) = cancelled {
            return std::task::Poll::Ready(Err(error));
        }

        // Nothing is pending anymore: hand off to the wrapped reader
        let projected = self.project();
        projected.reader.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the buffered data as consumed.
    ///
    /// # Panics
    ///
    /// Panics if a seek is still pending.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let projected = self.project();

        // Consuming in the middle of a seek is not allowed, so refuse
        // loudly before touching the buffer
        assert!(
            projected.pending_seek.is_none(),
            "tried to consume from buffer while seeking"
        );

        // Hand off to the wrapped reader
        projected.reader.consume(amt);
    }
}
1261
#[cfg(feature = "futures")]
impl<R> futures::AsyncRead for AsyncZstdSeekableReader<'_, R>
where
    R: futures::AsyncBufRead + futures::AsyncSeek,
{
    /// Reads decompressed bytes into `buf`, returning how many bytes
    /// were copied.
    ///
    /// Built on top of this type's own `poll_fill_buf` / `consume`:
    /// as many of the currently decoded bytes as fit into `buf` are
    /// copied over and then consumed.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use futures::AsyncBufRead as _;

        // Ask the buffered side for whatever decoded data is available
        let decoded = match self.as_mut().poll_fill_buf(cx) {
            std::task::Poll::Ready(Ok(decoded)) => decoded,
            std::task::Poll::Ready(Err(error)) => return std::task::Poll::Ready(Err(error)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        };

        // Never copy more than the destination can hold
        let copy_len = buf.len().min(decoded.len());
        buf[..copy_len].copy_from_slice(&decoded[..copy_len]);

        // The copied bytes are now owned by the caller, so drop them
        // from our buffer
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(copy_len))
    }
}
1289
1290#[cfg(feature = "futures")]
1291impl<R> futures::AsyncSeek for AsyncZstdSeekableReader<'_, R>
1292where
1293    R: futures::AsyncBufRead + futures::AsyncSeek,
1294{
1295    fn poll_seek(
1296        mut self: std::pin::Pin<&mut Self>,
1297        cx: &mut std::task::Context<'_>,
1298        position: std::io::SeekFrom,
1299    ) -> std::task::Poll<std::io::Result<u64>> {
1300        use crate::buffer::Buffer as _;
1301        use futures::io::AsyncBufRead as _;
1302
1303        // Iterate until the seek is complete. Each iteration should
1304        // make some progress based on the current pending seek state
1305        // (or should return `Poll::Pending`)
1306        loop {
1307            let this = self.as_mut().project();
1308
1309            let pending_seek = match *this.pending_seek {
1310                Some(pending_seek) if pending_seek.seek == position => {
1311                    // We're already seeking with the same position,
1312                    // so keep going from where we left off
1313                    pending_seek
1314                }
1315                _ => {
1316                    // If there's another seek operation in progress,
1317                    // cancel it first
1318                    ready!(self.as_mut().poll_cancel_seek_futures(cx))?;
1319
1320                    let this = self.as_mut().project();
1321
1322                    // Start a new seek operation
1323                    let pending_seek = PendingSeek {
1324                        initial_pos: this.reader.current_pos,
1325                        seek: position,
1326                        state: PendingSeekState::Starting,
1327                    };
1328                    *this.pending_seek = Some(pending_seek);
1329                    pending_seek
1330                }
1331            };
1332
1333            let mut this = self.as_mut().project();
1334
1335            match pending_seek.state {
1336                PendingSeekState::Starting => {
1337                    // Seek is starting. The first step is to determine
1338                    // the seek target relative to the start of the stream
1339
1340                    match pending_seek.seek {
1341                        std::io::SeekFrom::Start(offset) => {
1342                            // The offset is already relatve to the start,
1343                            // so transition to the state to start seeking
1344                            *this.pending_seek = Some(PendingSeek {
1345                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
1346                                ..pending_seek
1347                            });
1348                        }
1349                        std::io::SeekFrom::End(end_offset) => {
1350                            // To determine the offset relative to the start,
1351                            // we first need to reach the end of the stream.
1352
1353                            // Determine the best way to reach the last
1354                            // position we know about in the stream.
1355                            let seek = this.reader.decoder.prepare_seek_to_last_known_pos();
1356
1357                            if let Some(frame) = seek.seek_to_frame_start {
1358                                // We have a frame to seek to, so
1359                                // transition to the state to seek before
1360                                // jumping to the end of the stream
1361                                *this.pending_seek = Some(PendingSeek {
1362                                    state: PendingSeekState::SeekingToLastFrame {
1363                                        frame,
1364                                        end_offset,
1365                                    },
1366                                    ..pending_seek
1367                                })
1368                            } else {
1369                                // No need to seek, so transition to the
1370                                // state to jump to the end of the stream
1371                                *this.pending_seek = Some(PendingSeek {
1372                                    state: PendingSeekState::JumpingToEnd { end_offset },
1373                                    ..pending_seek
1374                                });
1375                            }
1376                        }
1377                        std::io::SeekFrom::Current(offset) => {
1378                            // Compute the offset relative to the current position
1379                            let offset = this.reader.current_pos.checked_add_signed(offset);
1380                            let offset = match offset {
1381                                Some(offset) => offset,
1382                                None => {
1383                                    // Offset overflowed, so clear the pending
1384                                    // seek and return an error
1385                                    *this.pending_seek = None;
1386                                    return std::task::Poll::Ready(Err(std::io::Error::other(
1387                                        "invalid seek offset",
1388                                    )));
1389                                }
1390                            };
1391
1392                            // Transition to the state to start seeking based
1393                            // on the computed offset
1394                            *this.pending_seek = Some(PendingSeek {
1395                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
1396                                ..pending_seek
1397                            });
1398                        }
1399                    }
1400                }
1401                PendingSeekState::SeekingToLastFrame { end_offset, frame } => {
1402                    // We're seeking to the last (known) frame in the stream
1403                    // before jumping to the end
1404
1405                    let reader = this.reader.as_mut().project();
1406
1407                    // Seek the underlying reader
1408                    let result = ready!(reader
1409                        .reader
1410                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));
1411                    match result {
1412                        Ok(_) => {}
1413                        Err(error) => {
1414                            // Seeking the underlying reader failed,
1415                            // so clear the in-progress seek and bail
1416                            *this.pending_seek = None;
1417                            return std::task::Poll::Ready(Err(error));
1418                        }
1419                    };
1420
1421                    // Update the decoder based on what frame we're now at
1422                    let result = reader.decoder.seeked_to_frame(frame);
1423                    match result {
1424                        Ok(_) => {}
1425                        Err(error) => {
1426                            // The decoder failed, so clear the
1427                            // in-progress seek and bail
1428                            *this.pending_seek = None;
1429                            return std::task::Poll::Ready(Err(error));
1430                        }
1431                    }
1432
1433                    // Update our internal position to align with the
1434                    // start of the frame
1435                    *reader.current_pos = frame.decompressed_pos;
1436
1437                    // Clear the buffer
1438                    reader.buffer.clear();
1439
1440                    // Seek complete, so transition states to jump to
1441                    // the end of the stream
1442                    *this.pending_seek = Some(PendingSeek {
1443                        state: PendingSeekState::JumpingToEnd { end_offset },
1444                        ..pending_seek
1445                    });
1446                }
1447                PendingSeekState::JumpingToEnd { end_offset } => {
1448                    // Seek target is relative to the end of the stream, so
1449                    // we need to jump to the end of the stream before
1450                    // determing the target position
1451
1452                    // Try to jump to the end of stream
1453                    let result = ready!(this.reader.as_mut().poll_jump_to_end_futures(cx));
1454                    match result {
1455                        Ok(_) => {}
1456                        Err(error) => {
1457                            // Jumping to the end failed, so cancel the
1458                            // in-progress seek and bail
1459                            *this.pending_seek = None;
1460                            return std::task::Poll::Ready(Err(error));
1461                        }
1462                    };
1463
1464                    // Now we're at the end of the stream, so we can now
1465                    // compute the target position
1466                    let target_pos = this.reader.current_pos.checked_add_signed(end_offset);
1467                    let target_pos = match target_pos {
1468                        Some(target_pos) => target_pos,
1469                        None => {
1470                            // Target position overflowed, so cancel the
1471                            // in-progress seek and bail
1472                            *this.pending_seek = None;
1473                            return std::task::Poll::Ready(Err(std::io::Error::other(
1474                                "invalid seek offset",
1475                            )));
1476                        }
1477                    };
1478
1479                    // Transition to the state to start seeking based
1480                    // on the computed offset
1481                    *this.pending_seek = Some(PendingSeek {
1482                        state: PendingSeekState::SeekingToTarget { target_pos },
1483                        ..pending_seek
1484                    });
1485                }
1486                PendingSeekState::SeekingToTarget { target_pos } => {
1487                    // We now know the relative position we're seeking to
1488
1489                    // Consume any leftover data in `self.buffer`. This ensures
1490                    // that the current position is in line with the
1491                    // underlying decoder
1492                    let consumable = this.reader.buffer.uncommitted().len();
1493                    this.reader.as_mut().consume(consumable);
1494
1495                    // Determine what we need to do to reach the target position
1496                    let seek = this
1497                        .reader
1498                        .decoder
1499                        .prepare_seek_to_decompressed_pos(target_pos);
1500
1501                    if let Some(frame) = seek.seek_to_frame_start {
1502                        // We need to seek to the start of a frame
1503
1504                        // Transition to the state so we poll until
1505                        // the seek completes
1506                        *this.pending_seek = Some(PendingSeek {
1507                            state: PendingSeekState::SeekingToFrame {
1508                                target_pos,
1509                                frame,
1510                                decompress_len: seek.decompress_len,
1511                            },
1512                            ..pending_seek
1513                        });
1514                    } else {
1515                        // We need to keep decoding bytes to reach the
1516                        // target position
1517
1518                        // Transition to the state so that we can keep
1519                        // decompressing until reaching the target position
1520                        *this.pending_seek = Some(PendingSeek {
1521                            state: PendingSeekState::JumpingForward {
1522                                target_pos,
1523                                decompress_len: seek.decompress_len,
1524                            },
1525                            ..pending_seek
1526                        });
1527                    }
1528                }
1529                PendingSeekState::SeekingToFrame {
1530                    target_pos,
1531                    frame,
1532                    decompress_len,
1533                } => {
1534                    // We're seeking to the start of a frame
1535
1536                    let reader = this.reader.as_mut().project();
1537
1538                    // Seek the underlying reader
1539                    let result = ready!(reader
1540                        .reader
1541                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));
1542                    match result {
1543                        Ok(_) => {}
1544                        Err(error) => {
1545                            // Seeking the underlying reader failed,
1546                            // so clear the in-progress seek and bail
1547                            *this.pending_seek = None;
1548                            return std::task::Poll::Ready(Err(error));
1549                        }
1550                    };
1551
1552                    // Update the decoder based on what frame we're now at
1553                    let result = reader.decoder.seeked_to_frame(frame);
1554                    match result {
1555                        Ok(_) => {}
1556                        Err(error) => {
1557                            // The decoder failed, so clear the
1558                            // in-progress seek and bail
1559                            *this.pending_seek = None;
1560                            return std::task::Poll::Ready(Err(error));
1561                        }
1562                    }
1563
1564                    // Update our internal position to align with the
1565                    // start of the frame
1566                    *reader.current_pos = frame.decompressed_pos;
1567
1568                    // Seek complete, so transition states to decompress
1569                    // until reaching the target position
1570                    *this.pending_seek = Some(PendingSeek {
1571                        state: PendingSeekState::JumpingForward {
1572                            target_pos,
1573                            decompress_len,
1574                        },
1575                        ..pending_seek
1576                    });
1577                }
1578                PendingSeekState::JumpingForward {
1579                    target_pos,
1580                    decompress_len: 0,
1581                } => {
1582                    // No more bytes to decompress, so at the target position!
1583                    // Clear the pending seek, then we're done
1584                    assert_eq!(target_pos, this.reader.current_pos);
1585                    *this.pending_seek = None;
1586                    return std::task::Poll::Ready(Ok(this.reader.current_pos));
1587                }
1588                PendingSeekState::JumpingForward {
1589                    target_pos,
1590                    decompress_len,
1591                } => {
1592                    // We have some bytes to decompress before reaching
1593                    // the target position
1594
1595                    // Try to decompress some data from the underlying
1596                    // reader
1597                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
1598                    let filled = match result {
1599                        Ok(filled) => filled,
1600                        Err(error) => {
1601                            // Failed to decompress from the underlying
1602                            // reader, so clear the pending seek and bail
1603                            *this.pending_seek = None;
1604                            return std::task::Poll::Ready(Err(error));
1605                        }
1606                    };
1607
1608                    if filled.is_empty() {
1609                        // The underlying reader didn't give us any data, which
1610                        // means we hit EOF while trying to seek. Clear the
1611                        // pending seek and bail
1612                        *this.pending_seek = None;
1613                        return std::task::Poll::Ready(Err(std::io::Error::new(
1614                            std::io::ErrorKind::UnexpectedEof,
1615                            "reached eof while trying to decode to offset",
1616                        )));
1617                    }
1618
1619                    // Consume as much data as we can to try and reach
1620                    // the total data to decompress
1621                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
1622                    let jump_len = filled_len_u64.min(decompress_len);
1623                    let jump_len_usize: usize = jump_len.try_into().unwrap();
1624                    this.reader.as_mut().consume(jump_len_usize);
1625
1626                    // Update the state based on how much more data
1627                    // we have left to decompress
1628                    *this.pending_seek = Some(PendingSeek {
1629                        state: PendingSeekState::JumpingForward {
1630                            target_pos,
1631                            decompress_len: decompress_len - jump_len,
1632                        },
1633                        ..pending_seek
1634                    });
1635                }
1636                PendingSeekState::RestoringSeekToFrame { .. }
1637                | PendingSeekState::RestoringJumpForward { .. } => {
1638                    // The seek was cancelled, so poll until the
1639                    // cancellation finished then return an error
1640                    ready!(self.as_mut().poll_cancel_seek_futures(cx))?;
1641                    return std::task::Poll::Ready(Err(std::io::Error::other("seek cancelled")));
1642                }
1643            }
1644        }
1645    }
1646}
1647
/// A builder that builds an [`AsyncZstdReader`] from the provided reader.
///
/// A seek table can optionally be supplied with [`Self::with_seek_table`]
/// before calling [`Self::build`].
pub struct ZstdReaderBuilder<R> {
    /// The underlying reader; `build` moves it into the resulting
    /// `AsyncZstdReader` as-is.
    reader: R,
    /// The seek table that `build` hands to the decoder. `with_buffered`
    /// initializes it to `ZstdSeekTable::empty()`, and `with_seek_table`
    /// replaces it.
    table: ZstdSeekTable,
}
1653
#[cfg(feature = "tokio")]
impl<R> ZstdReaderBuilder<tokio::io::BufReader<R>> {
    /// Start a builder for a plain `tokio` reader by wrapping it in a
    /// [`tokio::io::BufReader`] whose capacity is
    /// `zstd::zstd_safe::DCtx::in_size()`.
    fn new_tokio(reader: R) -> Self
    where
        R: tokio::io::AsyncRead,
    {
        let capacity = zstd::zstd_safe::DCtx::in_size();
        let buffered = tokio::io::BufReader::with_capacity(capacity, reader);
        Self::with_buffered(buffered)
    }
}
1664
#[cfg(feature = "futures")]
impl<R> ZstdReaderBuilder<futures::io::BufReader<R>> {
    /// Start a builder for a plain `futures` reader by wrapping it in a
    /// [`futures::io::BufReader`] whose capacity is
    /// `zstd::zstd_safe::DCtx::in_size()`.
    fn new_futures(reader: R) -> Self
    where
        R: futures::AsyncRead,
    {
        let capacity = zstd::zstd_safe::DCtx::in_size();
        let buffered = futures::io::BufReader::with_capacity(capacity, reader);
        Self::with_buffered(buffered)
    }
}
1676
1677impl<R> ZstdReaderBuilder<R> {
1678    fn with_buffered(reader: R) -> Self {
1679        ZstdReaderBuilder {
1680            reader,
1681            table: ZstdSeekTable::empty(),
1682        }
1683    }
1684
1685    /// Use the given seek table when seeking the resulting reader. This can
1686    /// greatly speed up seek operations when using a zstd stream that
1687    /// uses the [zstd seekable format].
1688    ///
1689    /// See [`crate::table::read_seek_table`] for reading a seek table.
1690    ///
1691    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
1692    pub fn with_seek_table(mut self, table: ZstdSeekTable) -> Self {
1693        self.table = table;
1694        self
1695    }
1696
1697    /// Build the reader.
1698    pub fn build(self) -> std::io::Result<AsyncZstdReader<'static, R>> {
1699        let zstd_decoder = zstd::stream::raw::Decoder::new()?;
1700        let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::DCtx::out_size()]);
1701        let decoder = ZstdFramedDecoder::new(zstd_decoder, self.table);
1702
1703        Ok(AsyncZstdReader {
1704            reader: self.reader,
1705            decoder,
1706            buffer,
1707            current_pos: 0,
1708        })
1709    }
1710}
1711
/// Bookkeeping for a seek that has been started but has not completed yet.
///
/// The seek state machine advances by storing a new value with an updated
/// `state` and the remaining fields copied over (`..pending_seek`), so
/// `initial_pos` and `seek` are carried unchanged from step to step.
#[cfg_attr(not(any(feature = "tokio", feature = "futures")), expect(dead_code))]
#[derive(Debug, Clone, Copy)]
struct PendingSeek {
    /// NOTE(review): presumably the decompressed position the reader was at
    /// when the seek began (e.g. so a cancelled seek can be undone) --
    /// confirm against the code that creates this value.
    initial_pos: u64,
    /// NOTE(review): presumably the seek request being serviced, as
    /// originally passed in by the caller -- confirm.
    seek: std::io::SeekFrom,
    /// The step of the seek that is currently in progress.
    state: PendingSeekState,
}
1719
/// The individual steps that a [`PendingSeek`] moves through.
///
/// Polling a seek handles the current state and then either returns or
/// stores the next state to be handled.
#[cfg_attr(not(any(feature = "tokio", feature = "futures")), expect(dead_code))]
#[derive(Debug, Clone, Copy)]
enum PendingSeekState {
    /// No work has been done for the seek yet.
    ///
    /// NOTE(review): presumably this is where the requested `SeekFrom` is
    /// resolved into one of the states below -- confirm against the seek
    /// implementation.
    Starting,
    /// Seeking the underlying reader to the start of the last known frame
    /// before jumping to the end of the stream. Once that seek completes,
    /// the decoder is told which frame it is now at, the current position
    /// is set to the start of the frame, the buffer is cleared, and the
    /// state moves on to `JumpingToEnd`.
    SeekingToLastFrame {
        /// Signed offset, relative to the end of the stream, of the
        /// position being seeked to.
        end_offset: i64,
        /// The frame whose `compressed_pos` the underlying reader is
        /// seeked to.
        frame: ZstdFrame,
    },
    /// Jumping to the end of the stream. Once there, `end_offset` is added
    /// to the current position to get the absolute target position, and the
    /// state moves on to `SeekingToTarget`.
    JumpingToEnd {
        /// Signed offset, relative to the end of the stream, of the
        /// position being seeked to.
        end_offset: i64,
    },
    /// The absolute target position is known. Any leftover buffered data is
    /// consumed, then the decoder is asked how to reach the target, which
    /// leads to either `SeekingToFrame` or `JumpingForward`.
    SeekingToTarget {
        /// Absolute position in the decompressed stream being seeked to.
        target_pos: u64,
    },
    /// Seeking the underlying reader to the start of `frame`. Once that
    /// seek completes, the decoder is told which frame it is now at, the
    /// current position is set to the start of the frame, and the state
    /// moves on to `JumpingForward`.
    SeekingToFrame {
        /// Absolute position in the decompressed stream being seeked to.
        target_pos: u64,
        /// The frame whose `compressed_pos` the underlying reader is
        /// seeked to.
        frame: ZstdFrame,
        /// Number of bytes to decompress past the start of `frame` in
        /// order to reach `target_pos`.
        decompress_len: u64,
    },
    /// Decompressing and discarding data to reach the target position. The
    /// seek completes once `decompress_len` reaches zero.
    JumpingForward {
        /// Absolute position in the decompressed stream being seeked to.
        target_pos: u64,
        /// Number of bytes still left to decompress and discard.
        decompress_len: u64,
    },
    /// The seek was cancelled. Polling the seek in this state drives the
    /// cancellation to completion and then reports a "seek cancelled" error.
    ///
    /// NOTE(review): presumably the reader is seeked back to `frame` so the
    /// position from before the seek can be restored -- confirm against the
    /// cancellation code.
    RestoringSeekToFrame {
        /// NOTE(review): presumably the frame to seek back to -- confirm.
        frame: ZstdFrame,
        /// NOTE(review): presumably the number of bytes to decompress past
        /// the start of `frame` afterwards -- confirm.
        decompress_len: u64,
    },
    /// The seek was cancelled. Polling the seek in this state drives the
    /// cancellation to completion and then reports a "seek cancelled" error.
    ///
    /// NOTE(review): presumably data is decompressed and discarded until
    /// the position from before the seek is reached again -- confirm
    /// against the cancellation code.
    RestoringJumpForward {
        /// NOTE(review): presumably the number of bytes still left to
        /// decompress and discard -- confirm.
        decompress_len: u64,
    },
}