zstd_framed/async_reader.rs
1use crate::{
2 decoder::ZstdFramedDecoder,
3 table::{ZstdFrame, ZstdSeekTable},
4};
5
pin_project_lite::pin_project! {
    /// A reader that decompresses a zstd stream from an underlying
    /// async reader. Works as either a `tokio` or `futures` reader if
    /// their respective features are enabled.
    ///
    /// The underlying reader `R` should implement the following traits:
    ///
    /// - `tokio`
    ///   - [`tokio::io::AsyncBufRead`] (required for [`tokio::io::AsyncRead`] and [`tokio::io::AsyncBufRead`] impls)
    ///   - (Optional) [`tokio::io::AsyncSeek`] (used when calling [`.seekable()`](AsyncZstdReader::seekable))
    /// - `futures`
    ///   - [`futures::AsyncBufRead`] (required for [`futures::AsyncRead`] and [`futures::AsyncBufRead`] impls)
    ///   - (Optional) [`futures::AsyncSeek`] (used when calling [`.seekable()`](AsyncZstdReader::seekable))
    ///
    /// For sync I/O support, see [`crate::ZstdReader`].
    ///
    /// ## Construction
    ///
    /// Create a builder using [`AsyncZstdReader::builder_tokio`] (recommended
    /// for `tokio`) or [`AsyncZstdReader::builder_futures`] (recommended for
    /// `futures`); or use [`AsyncZstdReader::builder_buffered`] to use a
    /// custom buffer for either. See [`ZstdReaderBuilder`] for build options.
    /// Call [`ZstdReaderBuilder::build`] to build the
    /// [`AsyncZstdReader`] instance.
    ///
    /// ```
    /// # #[cfg(feature = "tokio")]
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let compressed_file: &[u8] = &[];
    /// // Tokio example
    /// let reader = zstd_framed::AsyncZstdReader::builder_tokio(compressed_file)
    ///     // .with_seek_table(table) // Provide a seek table if available
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// # #[cfg(not(feature = "tokio"))]
    /// # fn main() { }
    /// ```
    ///
    /// ```
    /// # #[cfg(feature = "futures")]
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let compressed_file: &[u8] = &[];
    /// // futures example
    /// let reader = zstd_framed::AsyncZstdReader::builder_futures(compressed_file)
    ///     // .with_seek_table(table) // Provide a seek table if available
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// # #[cfg(not(feature = "futures"))]
    /// # fn main() { }
    /// ```
    ///
    /// ## Buffering
    ///
    /// The decompressed zstd output is always buffered internally. Since the
    /// reader must also implement [`tokio::io::AsyncBufRead`] /
    /// [`futures::AsyncBufRead`], the compressed input must also be buffered.
    ///
    /// [`AsyncZstdReader::builder_tokio`] and
    /// [`AsyncZstdReader::builder_futures`] will wrap any reader implementing
    /// [`tokio::io::AsyncRead`] or [`futures::AsyncRead`] (respectively)
    /// with a recommended buffer size for the input stream. For more control
    /// over how the input gets buffered, you can instead use
    /// [`AsyncZstdReader::builder_buffered`].
    ///
    /// ## Seeking
    ///
    /// If the underlying reader is seekable (i.e. it implements either
    /// [`tokio::io::AsyncSeek`] or [`futures::AsyncSeek`]), you can call
    /// [`.seekable()`](`AsyncZstdReader::seekable`) to convert it to a seekable reader. See
    /// [`AsyncZstdSeekableReader`] for notes and caveats about seeking.
    pub struct AsyncZstdReader<'dict, R> {
        // The underlying (buffered) source of compressed bytes. Pinned
        // because it is polled through `Pin<&mut R>`.
        #[pin]
        reader: R,
        // Decoder that turns compressed input into decompressed output
        // written to `buffer`.
        decoder: ZstdFramedDecoder<'dict>,
        // Decompressed bytes that were decoded but not yet consumed by
        // the caller.
        buffer: crate::buffer::FixedBuffer<Vec<u8>>,
        // Position in the decompressed stream: advanced by every `consume`
        // call, and reset to a frame's decompressed offset after seeking
        // to that frame.
        current_pos: u64,
    }
}
86
87impl<'dict, R> AsyncZstdReader<'dict, R> {
88 /// Create a new zstd reader that decompresses the zstd stream from
89 /// the underlying Tokio reader. The provided reader will be wrapped
90 /// with an appropriately-sized buffer.
91 #[cfg(feature = "tokio")]
92 pub fn builder_tokio(reader: R) -> ZstdReaderBuilder<tokio::io::BufReader<R>>
93 where
94 R: tokio::io::AsyncRead,
95 {
96 ZstdReaderBuilder::new_tokio(reader)
97 }
98
99 /// Create a new zstd reader that decompresses the zstd stream from
100 /// the underlying `futures` reader. The provided reader will be
101 /// wrapped with an appropriately-sized buffer.
102 #[cfg(feature = "futures")]
103 pub fn builder_futures(reader: R) -> ZstdReaderBuilder<futures::io::BufReader<R>>
104 where
105 R: futures::AsyncRead,
106 {
107 ZstdReaderBuilder::new_futures(reader)
108 }
109
110 /// Create a new zstd reader that decompresses the zstd stream from
111 /// the underlying reader. The underlying reader must implement
112 /// either [`tokio::io::AsyncBufRead`] or [`futures::AsyncBufRead`],
113 /// and its buffer will be used directly. When in doubt, use
114 /// one of the other builder methods to use an appropriate buffer size
115 /// for decompressing a zstd stream.
116 pub fn builder_buffered(reader: R) -> ZstdReaderBuilder<R> {
117 ZstdReaderBuilder::with_buffered(reader)
118 }
119
120 /// Wrap the reader with [`AsyncZstdSeekableReader`], which adds support
121 /// for seeking if the underlying reader supports seeking.
122 pub fn seekable(self) -> AsyncZstdSeekableReader<'dict, R> {
123 AsyncZstdSeekableReader {
124 reader: self,
125 pending_seek: None,
126 }
127 }
128
129 /// Decode and consume the entire zstd stream until reaching the end.
130 /// Stops when reaching EOF (i.e. the underlying reader had no more data)
131 #[cfg(feature = "tokio")]
132 fn poll_jump_to_end_tokio(
133 mut self: std::pin::Pin<&mut Self>,
134 cx: &mut std::task::Context<'_>,
135 ) -> std::task::Poll<std::io::Result<()>>
136 where
137 R: tokio::io::AsyncBufRead,
138 {
139 use tokio::io::AsyncBufRead as _;
140
141 loop {
142 // Decode some data from the underlying reader
143 let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;
144
145 // If we didn't get any more data, we're done
146 if decoded.is_empty() {
147 break;
148 }
149
150 // Consume all the decoded data
151 let decoded_len = decoded.len();
152 self.as_mut().consume(decoded_len);
153 }
154
155 std::task::Poll::Ready(Ok(()))
156 }
157
158 /// Decode and consume the entire zstd stream until reaching the end.
159 /// Stops when reaching EOF (i.e. the underlying reader had no more data)
160 #[cfg(feature = "futures")]
161 fn poll_jump_to_end_futures(
162 mut self: std::pin::Pin<&mut Self>,
163 cx: &mut std::task::Context<'_>,
164 ) -> std::task::Poll<std::io::Result<()>>
165 where
166 R: futures::AsyncBufRead,
167 {
168 use futures::AsyncBufRead as _;
169
170 loop {
171 // Decode some data from the underlying reader
172 let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;
173
174 // If we didn't get any more data, we're done
175 if decoded.is_empty() {
176 break;
177 }
178
179 // Consume all the decoded data
180 let decoded_len = decoded.len();
181 self.as_mut().consume(decoded_len);
182 }
183
184 std::task::Poll::Ready(Ok(()))
185 }
186}
187
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncBufRead for AsyncZstdReader<'_, R>
where
    R: tokio::io::AsyncBufRead,
{
    /// Return the decompressed bytes that are ready to be read, decoding
    /// more input from the underlying reader if nothing is buffered yet.
    ///
    /// An empty slice is returned only when the buffer is empty and the
    /// underlying reader reported no more data.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        use crate::buffer::Buffer as _;

        loop {
            let mut this = self.as_mut().project();

            // Already-decoded output takes priority over decoding more
            let has_pending_output = !this.buffer.uncommitted().is_empty();
            if has_pending_output {
                break;
            }

            // Nothing buffered: pull compressed input from the source
            let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
            if input.is_empty() {
                // Source is exhausted, so there is nothing more to decode
                break;
            }

            // Feed the input through the decoder (output lands in our
            // buffer), then mark only the bytes the decoder actually
            // used as read on the source
            let used = this.decoder.decode(input, this.buffer)?;
            this.reader.consume(used);
        }

        // Hand out whatever decoded output is currently buffered
        std::task::Poll::Ready(Ok(self.project().buffer.uncommitted()))
    }

    /// Mark `amt` bytes of decoded output as read and advance the
    /// decompressed position by the same amount.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        use crate::buffer::Buffer as _;

        let this = self.project();

        // Drop the consumed bytes from the front of the decoded output
        this.buffer.commit(amt);

        // Keep the stream position in step with what was consumed
        let advanced_by: u64 = amt.try_into().unwrap();
        *this.current_pos += advanced_by;
    }
}
241
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncRead for AsyncZstdReader<'_, R>
where
    R: tokio::io::AsyncBufRead,
{
    /// Read decompressed bytes into `buf`.
    ///
    /// At most one chunk of buffered output is copied per call, capped
    /// to the space remaining in `buf`. If `buf` has no remaining
    /// capacity, this returns `Poll::Ready(Ok(()))` immediately without
    /// polling the underlying reader (matching the `futures`
    /// implementation of this reader, which returns early for an empty
    /// buffer). Completing without adding any bytes to a non-full `buf`
    /// indicates the end of the stream.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use tokio::io::AsyncBufRead as _;

        // A read into a full buffer can't make progress, so don't wait
        // on (or decode from) the underlying reader for it
        if buf.remaining() == 0 {
            return std::task::Poll::Ready(Ok(()));
        }

        // Decode some data from the underlying reader
        let filled = ready!(self.as_mut().poll_fill_buf(cx))?;

        // Get some of the decoded data, capped to `buf`'s remaining space
        let consumable = filled.len().min(buf.remaining());

        // Copy the decoded data to `buf`
        buf.put_slice(&filled[..consumable]);

        // Consume the copied data
        self.consume(consumable);

        std::task::Poll::Ready(Ok(()))
    }
}
269
#[cfg(feature = "futures")]
impl<R> futures::AsyncBufRead for AsyncZstdReader<'_, R>
where
    R: futures::io::AsyncBufRead,
{
    /// Return the decompressed bytes that are ready to be read, decoding
    /// more input from the underlying reader if nothing is buffered yet.
    ///
    /// An empty slice is returned only when the buffer is empty and the
    /// underlying reader reported no more data.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        use crate::buffer::Buffer as _;

        loop {
            let mut this = self.as_mut().project();

            // Already-decoded output takes priority over decoding more
            let has_pending_output = !this.buffer.uncommitted().is_empty();
            if has_pending_output {
                break;
            }

            // Nothing buffered: pull compressed input from the source
            let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
            if input.is_empty() {
                // Source is exhausted, so there is nothing more to decode
                break;
            }

            // Feed the input through the decoder (output lands in our
            // buffer), then mark only the bytes the decoder actually
            // used as read on the source
            let used = this.decoder.decode(input, this.buffer)?;
            this.reader.consume(used);
        }

        // Hand out whatever decoded output is currently buffered
        std::task::Poll::Ready(Ok(self.project().buffer.uncommitted()))
    }

    /// Mark `amt` bytes of decoded output as read and advance the
    /// decompressed position by the same amount.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        use crate::buffer::Buffer as _;

        let this = self.project();

        // Drop the consumed bytes from the front of the decoded output
        this.buffer.commit(amt);

        // Keep the stream position in step with what was consumed
        let advanced_by: u64 = amt.try_into().unwrap();
        *this.current_pos += advanced_by;
    }
}
323
#[cfg(feature = "futures")]
impl<R> futures::AsyncRead for AsyncZstdReader<'_, R>
where
    R: futures::AsyncBufRead,
{
    /// Read decompressed bytes into `buf`, returning how many bytes were
    /// written.
    ///
    /// At most one chunk of buffered output is copied per call. An empty
    /// `buf` resolves to `Ok(0)` right away without touching the
    /// underlying reader.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use futures::AsyncBufRead as _;

        // Zero-length reads complete trivially
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }

        // Make sure some decoded output is available
        let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;

        // We can copy no more than the caller has room for
        let copy_len = std::cmp::min(decoded.len(), buf.len());

        // Fill the front of the caller's buffer with decoded output
        let (dest, _) = buf.split_at_mut(copy_len);
        dest.copy_from_slice(&decoded[..copy_len]);

        // The copied bytes are now owned by the caller
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(copy_len))
    }
}
355
pin_project_lite::pin_project! {
    /// A wrapper around [`AsyncZstdReader`] with extra support for seeking.
    /// Created via the method [`AsyncZstdReader::seekable`].
    ///
    /// The underlying reader `R` should implement the following traits:
    ///
    /// - `tokio`
    ///   - [`tokio::io::AsyncBufRead`] + [`tokio::io::AsyncSeek`] (required for [`tokio::io::AsyncRead`], [`tokio::io::AsyncBufRead`], and [`tokio::io::AsyncSeek`] impls)
    /// - `futures`
    ///   - [`futures::AsyncBufRead`] + [`futures::AsyncSeek`] (required for [`futures::AsyncRead`], [`futures::AsyncBufRead`], and [`futures::AsyncSeek`] impls)
    ///
    /// **By default, seeking within the stream will linearly decompress
    /// until reaching the target!**
    ///
    /// Seeking can do a lot better when the underlying stream is broken up
    /// into multiple frames, such as a stream that uses the [zstd seekable format].
    /// You can create such a stream using [`ZstdWriterBuilder::with_seek_table`](crate::async_writer::ZstdWriterBuilder::with_seek_table).
    ///
    /// There are two situations where seeking can take advantage of a seek
    /// table:
    ///
    /// 1. When a seek table is provided up-front using [`ZstdReaderBuilder::with_seek_table`].
    ///    See [`crate::table::read_seek_table`] for reading a seek table
    ///    from a reader (there are also async-friendly functions available).
    /// 2. When rewinding to a previously-decompressed frame. Frame offsets are
    ///    automatically recorded during decompression.
    ///
    /// Even if a seek table is used, seeking will still need to rewind to
    /// the start of a frame, then decompress until reaching the target offset.
    ///
    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
    pub struct AsyncZstdSeekableReader<'dict, R> {
        // The wrapped decompressing reader. Pinned because it is polled
        // through `Pin<&mut _>`.
        #[pin]
        reader: AsyncZstdReader<'dict, R>,
        // State of the seek currently in flight, or `None` when no seek
        // is in progress.
        pending_seek: Option<PendingSeek>,
    }
}
394
impl<R> AsyncZstdSeekableReader<'_, R> {
    /// If a seek operation was started with [`tokio::io::AsyncSeek::start_seek`]
    /// but wasn't polled to completion, "undo" the seek by seeking
    /// back to where we were in the zstd stream.
    ///
    /// This drives a small state machine stored in `pending_seek`:
    ///
    /// - `Starting`: nothing has happened yet, so the seek is simply
    ///   dropped.
    /// - Any in-progress seek state: work out how to get back to
    ///   `initial_pos`, then move to one of the `Restoring*` states.
    /// - `RestoringSeekToFrame`: wait for the underlying reader to
    ///   finish seeking to the start of a frame.
    /// - `RestoringJumpForward`: decompress and discard data until the
    ///   initial position is reached.
    ///
    /// Returns `Poll::Ready(Ok(()))` once no seek is pending (immediately,
    /// if there was none to begin with), and `Poll::Pending` whenever the
    /// underlying reader isn't ready.
    ///
    /// # Errors
    ///
    /// Any failure clears the pending seek before returning. Errors from
    /// the underlying reader or decoder are wrapped in a new
    /// [`std::io::Error::other`] error with the original error included
    /// in the message. Hitting EOF before reaching the initial position
    /// yields an [`std::io::ErrorKind::UnexpectedEof`] error.
    ///
    /// # Panics
    ///
    /// Panics if, after restoring, the reader's position doesn't match
    /// the position recorded when the seek was started.
    #[cfg(feature = "tokio")]
    fn poll_cancel_seek_tokio(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>>
    where
        R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
    {
        use crate::buffer::Buffer as _;
        use tokio::io::AsyncBufRead as _;

        let mut this = self.project();

        // Iterate until `self.pending_seek` is unset. Each iteration
        // should make some progress based on the current pending seek
        // state (or should return `Poll::Pending`)
        loop {
            let Some(pending_seek) = *this.pending_seek else {
                // No pending seek, which means we're done!
                return std::task::Poll::Ready(Ok(()));
            };

            match pending_seek.state {
                PendingSeekState::Starting => {
                    // Seek just started. There's nothing to undo, so
                    // just clear the pending seek
                    *this.pending_seek = None;
                }
                PendingSeekState::SeekingToLastFrame { .. }
                | PendingSeekState::JumpingToEnd { .. }
                | PendingSeekState::SeekingToTarget { .. }
                | PendingSeekState::SeekingToFrame { .. }
                | PendingSeekState::JumpingForward { .. } => {
                    // Seek is in progress

                    // Consume any leftover data in `self.buffer`. This ensures
                    // that the current position is in line with the
                    // underlying decoder
                    let consumable = this.reader.buffer.uncommitted().len();
                    this.reader.as_mut().consume(consumable);

                    // Determine what we need to do to get back to the
                    // initial position (i.e. the position recorded in
                    // the pending seek when it was started)
                    let seek = this
                        .reader
                        .decoder
                        .prepare_seek_to_decompressed_pos(pending_seek.initial_pos);

                    if let Some(frame) = seek.seek_to_frame_start {
                        // We need to seek to the start of a frame

                        // Transition the state to indicate we're seeking
                        // to the start of a frame. This is done before
                        // submitting the seek; if submitting fails, the
                        // pending seek is cleared again below
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringSeekToFrame {
                                frame,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });

                        // Submit a seek job to the underlying reader. It
                        // gets polled to completion in the
                        // `RestoringSeekToFrame` state
                        let reader = this.reader.as_mut().project().reader;
                        let result =
                            reader.start_seek(std::io::SeekFrom::Start(frame.compressed_pos));

                        match result {
                            Ok(_) => {}
                            Err(error) => {
                                // Trying to seek the underlying reader
                                // failed, so clear the pending seek and bail
                                *this.pending_seek = None;
                                return std::task::Poll::Ready(Err(std::io::Error::other(
                                    format!("failed to cancel in-progress zstd seek: {error}"),
                                )));
                            }
                        }
                    } else {
                        // We just need to keep decompressing to reach the
                        // initial position

                        // Transition to a state to indicate how many
                        // bytes we need to consume
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringJumpForward {
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    }
                }
                PendingSeekState::RestoringSeekToFrame {
                    frame,
                    decompress_len,
                } => {
                    // We're in the process of restoring to the previous
                    // seek position, and need to seek to the start of
                    // a frame

                    let reader = this.reader.as_mut().project();

                    // Poll until the seek completes. If it's still
                    // pending, we return here and resume from this same
                    // state on the next poll
                    let result = ready!(reader.reader.poll_complete(cx));

                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed, so
                            // clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our internal position to align with the start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Error from the decoder, so clear the pending
                            // seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our state to decompress as much data as we
                    // need to reach the initial position again
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward { decompress_len },
                        ..pending_seek
                    });
                }
                PendingSeekState::RestoringJumpForward { decompress_len: 0 } => {
                    // We finished getting back to the initial position!
                    // Clear the pending seek

                    // Sanity check: we should be exactly where we were
                    // when the seek was started
                    assert_eq!(pending_seek.initial_pos, this.reader.current_pos);
                    *this.pending_seek = None;
                }
                PendingSeekState::RestoringJumpForward { decompress_len } => {
                    // We have to decompress some data to reach the initial
                    // position

                    // Try to decompress some data from the underlying
                    // reader
                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
                    let filled = match result {
                        Ok(filled) => filled,
                        Err(error) => {
                            // Failed to decompress from the underlying
                            // reader, so clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    };

                    if filled.is_empty() {
                        // The underlying reader didn't give us any data, which
                        // means we hit EOF while trying to get back to the
                        // initial position. Clear the pending seek and bail

                        *this.pending_seek = None;
                        return std::task::Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "reached eof while trying to cancel in-progress zstd seek",
                        )));
                    }

                    // Consume as much data as we can to try and reach
                    // the total data to decompress, but never more than
                    // what's left to skip
                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
                    let jump_len = filled_len_u64.min(decompress_len);
                    let jump_len_usize: usize = jump_len.try_into().unwrap();
                    this.reader.as_mut().consume(jump_len_usize);

                    // Update the state based on how much more data
                    // we have left to decompress
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward {
                            decompress_len: decompress_len - jump_len,
                        },
                        ..pending_seek
                    });
                }
            }
        }
    }

    /// If a seek operation was started with [`futures::AsyncSeek::poll_seek`]
    /// but wasn't polled to completion, "undo" the seek by seeking
    /// back to where we were in the zstd stream.
    ///
    /// This drives a small state machine stored in `pending_seek`:
    ///
    /// - `Starting`: nothing has happened yet, so the seek is simply
    ///   dropped.
    /// - Any in-progress seek state: work out how to get back to
    ///   `initial_pos`, then move to one of the `Restoring*` states.
    /// - `RestoringSeekToFrame`: poll the underlying reader until it
    ///   has seeked to the start of a frame.
    /// - `RestoringJumpForward`: decompress and discard data until the
    ///   initial position is reached.
    ///
    /// Returns `Poll::Ready(Ok(()))` once no seek is pending (immediately,
    /// if there was none to begin with), and `Poll::Pending` whenever the
    /// underlying reader isn't ready.
    ///
    /// # Errors
    ///
    /// Any failure clears the pending seek before returning. Errors from
    /// the underlying reader or decoder are wrapped in a new
    /// [`std::io::Error::other`] error with the original error included
    /// in the message. Hitting EOF before reaching the initial position
    /// yields an [`std::io::ErrorKind::UnexpectedEof`] error.
    ///
    /// # Panics
    ///
    /// Panics if, after restoring, the reader's position doesn't match
    /// the position recorded when the seek was started.
    #[cfg(feature = "futures")]
    fn poll_cancel_seek_futures(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>>
    where
        R: futures::AsyncBufRead + futures::AsyncSeek,
    {
        use crate::buffer::Buffer as _;
        use futures::AsyncBufRead as _;

        let mut this = self.project();

        // Iterate until `self.pending_seek` is unset. Each iteration
        // should make some progress based on the current pending seek
        // state (or should return `Poll::Pending`)
        loop {
            let Some(pending_seek) = *this.pending_seek else {
                // No pending seek, which means we're done!
                return std::task::Poll::Ready(Ok(()));
            };

            match pending_seek.state {
                PendingSeekState::Starting => {
                    // Seek just started. There's nothing to undo, so
                    // just clear the pending seek
                    *this.pending_seek = None;
                }
                PendingSeekState::SeekingToLastFrame { .. }
                | PendingSeekState::JumpingToEnd { .. }
                | PendingSeekState::SeekingToTarget { .. }
                | PendingSeekState::SeekingToFrame { .. }
                | PendingSeekState::JumpingForward { .. } => {
                    // Seek is in progress

                    // Consume any leftover data in `self.buffer`. This ensures
                    // that the current position is in line with the
                    // underlying decoder
                    let consumable = this.reader.buffer.uncommitted().len();
                    this.reader.as_mut().consume(consumable);

                    // Determine what we need to do to get back to the
                    // initial position (i.e. the position recorded in
                    // the pending seek when it was started)
                    let seek = this
                        .reader
                        .decoder
                        .prepare_seek_to_decompressed_pos(pending_seek.initial_pos);

                    if let Some(frame) = seek.seek_to_frame_start {
                        // We need to seek to the start of a frame

                        // Transition the state to indicate we're seeking
                        // to the start of a frame. Nothing is submitted
                        // to the underlying reader here: the seek target
                        // is passed to `poll_seek` in the
                        // `RestoringSeekToFrame` state
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringSeekToFrame {
                                frame,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    } else {
                        // We just need to keep decompressing to reach the
                        // initial position

                        // Transition to a state to indicate how many
                        // bytes we need to consume
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::RestoringJumpForward {
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    }
                }
                PendingSeekState::RestoringSeekToFrame {
                    frame,
                    decompress_len,
                } => {
                    // We're in the process of restoring to the previous
                    // seek position, and need to seek to the start of
                    // a frame

                    let reader = this.reader.as_mut().project();

                    // Poll until we finish seeking the underlying reader.
                    // If it's still pending, we return here and poll with
                    // the same target again on the next call
                    let result = ready!(reader
                        .reader
                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));

                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed, so
                            // clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our internal position to align with the start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Error from the decoder, so clear the pending
                            // seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    }

                    // Update our state to decompress as much data as we
                    // need to reach the initial position again
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward { decompress_len },
                        ..pending_seek
                    });
                }
                PendingSeekState::RestoringJumpForward { decompress_len: 0 } => {
                    // We finished getting back to the initial position!
                    // Clear the pending seek

                    // Sanity check: we should be exactly where we were
                    // when the seek was started
                    assert_eq!(pending_seek.initial_pos, this.reader.current_pos);
                    *this.pending_seek = None;
                }
                PendingSeekState::RestoringJumpForward { decompress_len } => {
                    // We have to decompress some data to reach the initial
                    // position

                    // Try to decompress some data from the underlying
                    // reader
                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
                    let filled = match result {
                        Ok(filled) => filled,
                        Err(error) => {
                            // Failed to decompress from the underlying
                            // reader, so clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(format!(
                                "failed to cancel in-progress zstd seek: {error}"
                            ))));
                        }
                    };

                    if filled.is_empty() {
                        // The underlying reader didn't give us any data, which
                        // means we hit EOF while trying to get back to the
                        // initial position. Clear the pending seek and bail

                        *this.pending_seek = None;
                        return std::task::Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "reached eof while trying to cancel in-progress zstd seek",
                        )));
                    }

                    // Consume as much data as we can to try and reach
                    // the total data to decompress, but never more than
                    // what's left to skip
                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
                    let jump_len = filled_len_u64.min(decompress_len);
                    let jump_len_usize: usize = jump_len.try_into().unwrap();
                    this.reader.as_mut().consume(jump_len_usize);

                    // Update the state based on how much more data
                    // we have left to decompress
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::RestoringJumpForward {
                            decompress_len: decompress_len - jump_len,
                        },
                        ..pending_seek
                    });
                }
            }
        }
    }
}
781
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncBufRead for AsyncZstdSeekableReader<'_, R>
where
    R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
{
    /// Return the decompressed bytes that are ready to be read.
    ///
    /// If a seek was started but never completed, it is undone first so
    /// that reading continues from the position before the seek.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        // Roll back any seek that is still in flight before reading
        ready!(self.as_mut().poll_cancel_seek_tokio(cx))?;

        // With no seek pending, the wrapped reader does the real work
        self.project().reader.poll_fill_buf(cx)
    }

    /// Mark `amt` bytes of decoded output as read.
    ///
    /// # Panics
    ///
    /// Panics if a seek is currently pending.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let this = self.project();

        // Consuming buffered data in the middle of a seek would leave
        // the reader in an inconsistent spot, so refuse outright
        if this.pending_seek.is_some() {
            panic!("tried to consume from buffer while seeking");
        }

        // Otherwise, the wrapped reader does the real work
        this.reader.consume(amt);
    }
}
813
#[cfg(feature = "tokio")]
impl<R> tokio::io::AsyncRead for AsyncZstdSeekableReader<'_, R>
where
    R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
{
    /// Read decompressed bytes into `buf`.
    ///
    /// At most one chunk of buffered output is copied per call, capped
    /// to the space remaining in `buf`.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use tokio::io::AsyncBufRead as _;

        // Make sure some decoded output is available
        let decoded = ready!(self.as_mut().poll_fill_buf(cx))?;

        // We can copy no more than the caller has room for
        let copy_len = std::cmp::min(decoded.len(), buf.remaining());

        // Append the front of the decoded output to the caller's buffer
        let (head, _) = decoded.split_at(copy_len);
        buf.put_slice(head);

        // The copied bytes are now owned by the caller
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(()))
    }
}
841
842#[cfg(feature = "tokio")]
843impl<R> tokio::io::AsyncSeek for AsyncZstdSeekableReader<'_, R>
844where
845 R: tokio::io::AsyncBufRead + tokio::io::AsyncSeek,
846{
847 fn start_seek(
848 self: std::pin::Pin<&mut Self>,
849 position: std::io::SeekFrom,
850 ) -> std::io::Result<()> {
851 let mut this = self.project();
852
853 // Ensure there isn't another seek in progress first
854 if this.pending_seek.is_some() {
855 return Err(std::io::Error::other("seek already in progress"));
856 }
857
858 // Transition to the "starting" seek state
859 *this.pending_seek = Some(PendingSeek {
860 initial_pos: this.reader.as_mut().current_pos,
861 seek: position,
862 state: PendingSeekState::Starting,
863 });
864 Ok(())
865 }
866
867 fn poll_complete(
868 mut self: std::pin::Pin<&mut Self>,
869 cx: &mut std::task::Context<'_>,
870 ) -> std::task::Poll<std::io::Result<u64>> {
871 use crate::buffer::Buffer as _;
872 use tokio::io::AsyncBufRead as _;
873
874 // Iterate until `self.pending_seek` is unset. Each iteration
875 // should make some progress based on the current pending seek
876 // state (or should return `Poll::Pending`)
877 loop {
878 let mut this = self.as_mut().project();
879
880 let Some(pending_seek) = *this.pending_seek else {
881 return std::task::Poll::Ready(Ok(this.reader.current_pos));
882 };
883
884 match pending_seek.state {
885 PendingSeekState::Starting => {
886 // Seek is starting. The first step is to determine
887 // the seek target relative to the start of the stream
888
889 match pending_seek.seek {
890 std::io::SeekFrom::Start(offset) => {
891 // The offset is already relatve to the start,
892 // so transition to the state to start seeking
893 *this.pending_seek = Some(PendingSeek {
894 state: PendingSeekState::SeekingToTarget { target_pos: offset },
895 ..pending_seek
896 });
897 }
898 std::io::SeekFrom::End(end_offset) => {
899 // To determine the offset relative to the start,
900 // we first need to reach the end of the stream.
901
902 // Determine the best way to reach the last
903 // position we know about in the stream.
904 let seek = this.reader.decoder.prepare_seek_to_last_known_pos();
905
906 if let Some(frame) = seek.seek_to_frame_start {
907 // We have a frame to seek to
908
909 // Start a job to seek to the start
910 // of the last known frame
911 let reader = this.reader.as_mut().project().reader;
912 let result = reader
913 .start_seek(std::io::SeekFrom::Start(frame.compressed_pos));
914
915 match result {
916 Ok(_) => {}
917 Err(error) => {
918 // Trying to seek the underlying reader
919 // failed, so clear the pending seek and bail
920 *this.pending_seek = None;
921 return std::task::Poll::Ready(Err(std::io::Error::other(
922 format!(
923 "failed to cancel in-progress zstd seek: {error}"
924 ),
925 )));
926 }
927 }
928
929 // Transition to the state to seek before
930 // jumping to the end of the stream
931 *this.pending_seek = Some(PendingSeek {
932 state: PendingSeekState::SeekingToLastFrame {
933 frame,
934 end_offset,
935 },
936 ..pending_seek
937 })
938 } else {
939 // No need to seek, so transition to the
940 // state to jump to the end of the stream
941 *this.pending_seek = Some(PendingSeek {
942 state: PendingSeekState::JumpingToEnd { end_offset },
943 ..pending_seek
944 });
945 }
946 }
947 std::io::SeekFrom::Current(offset) => {
948 // Compute the offset relative to the current position
949 let offset = this.reader.current_pos.checked_add_signed(offset);
950 let offset = match offset {
951 Some(offset) => offset,
952 None => {
953 // Offset overflowed, so clear the pending
954 // seek and return an error
955 *this.pending_seek = None;
956 return std::task::Poll::Ready(Err(std::io::Error::other(
957 "invalid seek offset",
958 )));
959 }
960 };
961
962 // Transition to the state to start seeking based
963 // on the computed offset
964 *this.pending_seek = Some(PendingSeek {
965 state: PendingSeekState::SeekingToTarget { target_pos: offset },
966 ..pending_seek
967 });
968 }
969 }
970 }
971 PendingSeekState::SeekingToLastFrame { end_offset, frame } => {
972 // We're seeking to the last (known) frame in the stream
973 // before jumping to the end
974
975 let reader = this.reader.as_mut().project();
976
977 // Wait for the seek on the underlying reader to complete
978 let result = ready!(reader.reader.poll_complete(cx));
979 match result {
980 Ok(_) => {}
981 Err(error) => {
982 // Seeking the underlying reader failed,
983 // so clear the in-progress seek and bail
984 *this.pending_seek = None;
985 return std::task::Poll::Ready(Err(error));
986 }
987 };
988
989 // Update the decoder based on what frame we're now at
990 let result = reader.decoder.seeked_to_frame(frame);
991 match result {
992 Ok(_) => {}
993 Err(error) => {
994 // The decoder failed, so clear the
995 // in-progress seek and bail
996 *this.pending_seek = None;
997 return std::task::Poll::Ready(Err(error));
998 }
999 }
1000
1001 // Update our internal position to align with the
1002 // start of the frame
1003 *reader.current_pos = frame.decompressed_pos;
1004
1005 // Clear the buffer
1006 reader.buffer.clear();
1007
1008 // Seek complete, so transition states to jump to
1009 // the end of the stream
1010 *this.pending_seek = Some(PendingSeek {
1011 state: PendingSeekState::JumpingToEnd { end_offset },
1012 ..pending_seek
1013 });
1014 }
1015 PendingSeekState::JumpingToEnd { end_offset } => {
1016 // Seek target is relative to the end of the stream, so
1017 // we need to jump to the end of the stream before
1018 // determing the target position
1019
1020 // Try to jump to the end of stream
1021 let result = ready!(this.reader.as_mut().poll_jump_to_end_tokio(cx));
1022 match result {
1023 Ok(_) => {}
1024 Err(error) => {
1025 // Jumping to the end failed, so cancel the
1026 // in-progress seek and bail
1027 *this.pending_seek = None;
1028 return std::task::Poll::Ready(Err(error));
1029 }
1030 };
1031
1032 // Now we're at the end of the stream, so we can now
1033 // compute the target position
1034 let target_pos = this.reader.current_pos.checked_add_signed(end_offset);
1035 let target_pos = match target_pos {
1036 Some(target_pos) => target_pos,
1037 None => {
1038 // Target position overflowed, so cancel the
1039 // in-progress seek and bail
1040 *this.pending_seek = None;
1041 return std::task::Poll::Ready(Err(std::io::Error::other(
1042 "invalid seek offset",
1043 )));
1044 }
1045 };
1046
1047 // Transition to the state to start seeking based
1048 // on the computed offset
1049 *this.pending_seek = Some(PendingSeek {
1050 state: PendingSeekState::SeekingToTarget { target_pos },
1051 ..pending_seek
1052 });
1053 }
1054 PendingSeekState::SeekingToTarget { target_pos } => {
1055 // We now know the relative position we're seeking to
1056
1057 // Consume any leftover data in `self.buffer`. This ensures
1058 // that the current position is in line with the
1059 // underlying decoder
1060 let consumable = this.reader.buffer.uncommitted().len();
1061 this.reader.as_mut().consume(consumable);
1062
1063 // Determine what we need to do to reach the target position
1064 let seek = this
1065 .reader
1066 .decoder
1067 .prepare_seek_to_decompressed_pos(target_pos);
1068
1069 if let Some(frame) = seek.seek_to_frame_start {
1070 // We need to seek to the start of a frame
1071
1072 // Transition to the state so we poll until
1073 // the seek completes
1074 *this.pending_seek = Some(PendingSeek {
1075 state: PendingSeekState::SeekingToFrame {
1076 target_pos,
1077 frame,
1078 decompress_len: seek.decompress_len,
1079 },
1080 ..pending_seek
1081 });
1082
1083 let reader = this.reader.as_mut().project().reader;
1084
1085 // Start a job to seek the underlying reader
1086 let result =
1087 reader.start_seek(std::io::SeekFrom::Start(frame.compressed_pos));
1088
1089 match result {
1090 Ok(_) => {}
1091 Err(error) => {
1092 // Trying to seek the underlying reader
1093 // failed, so clear the in-progress seek
1094 // and bail
1095 *this.pending_seek = None;
1096 return std::task::Poll::Ready(Err(error));
1097 }
1098 }
1099 } else {
1100 // We need to keep decoding bytes to reach the
1101 // target position
1102
1103 // Transition to the state so that we can keep
1104 // decompressing until reaching the target position
1105 *this.pending_seek = Some(PendingSeek {
1106 state: PendingSeekState::JumpingForward {
1107 target_pos,
1108 decompress_len: seek.decompress_len,
1109 },
1110 ..pending_seek
1111 });
1112 }
1113 }
1114 PendingSeekState::SeekingToFrame {
1115 target_pos,
1116 frame,
1117 decompress_len,
1118 } => {
1119 // We're seeking to the start of a frame
1120
1121 let reader = this.reader.as_mut().project();
1122
1123 // Poll until the underlying reader finishes seeking
1124 let result = ready!(reader.reader.poll_complete(cx));
1125 match result {
1126 Ok(_) => {}
1127 Err(error) => {
1128 // Seeking the underlying reader failed,
1129 // so clear the in-progress seek and bail
1130 *this.pending_seek = None;
1131 return std::task::Poll::Ready(Err(error));
1132 }
1133 };
1134
1135 // Update the decoder based on what frame we're now at
1136 let result = reader.decoder.seeked_to_frame(frame);
1137 match result {
1138 Ok(_) => {}
1139 Err(error) => {
1140 // The decoder failed, so clear the
1141 // in-progress seek and bail
1142 *this.pending_seek = None;
1143 return std::task::Poll::Ready(Err(error));
1144 }
1145 }
1146
1147 // Update our internal position to align with the
1148 // start of the frame
1149 *reader.current_pos = frame.decompressed_pos;
1150
1151 // Seek complete, so transition states to decompress
1152 // until reaching the target position
1153 *this.pending_seek = Some(PendingSeek {
1154 state: PendingSeekState::JumpingForward {
1155 target_pos,
1156 decompress_len,
1157 },
1158 ..pending_seek
1159 });
1160 }
1161 PendingSeekState::JumpingForward {
1162 target_pos,
1163 decompress_len: 0,
1164 } => {
1165 // No more bytes to decompress, so at the target position!
1166 // Clear the pending seek
1167 assert_eq!(target_pos, this.reader.current_pos);
1168 *this.pending_seek = None;
1169 }
1170 PendingSeekState::JumpingForward {
1171 target_pos,
1172 decompress_len,
1173 } => {
1174 // We have some bytes to decompress before reaching
1175 // the target position
1176
1177 // Try to decompress some data from the underlying
1178 // reader
1179 let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
1180 let filled = match result {
1181 Ok(filled) => filled,
1182 Err(error) => {
1183 // Failed to decompress from the underlying
1184 // reader, so clear the pending seek and bail
1185 *this.pending_seek = None;
1186 return std::task::Poll::Ready(Err(error));
1187 }
1188 };
1189
1190 if filled.is_empty() {
1191 // The underlying reader didn't give us any data, which
1192 // means we hit EOF while trying to seek. Clear the
1193 // pending seek and bail
1194 *this.pending_seek = None;
1195 return std::task::Poll::Ready(Err(std::io::Error::new(
1196 std::io::ErrorKind::UnexpectedEof,
1197 "reached eof while trying to decode to offset",
1198 )));
1199 }
1200
1201 // Consume as much data as we can to try and reach
1202 // the total data to decompress
1203 let filled_len_u64: u64 = filled.len().try_into().unwrap();
1204 let jump_len = filled_len_u64.min(decompress_len);
1205 let jump_len_usize: usize = jump_len.try_into().unwrap();
1206 this.reader.as_mut().consume(jump_len_usize);
1207
1208 // Update the state based on how much more data
1209 // we have left to decompress
1210 *this.pending_seek = Some(PendingSeek {
1211 state: PendingSeekState::JumpingForward {
1212 target_pos,
1213 decompress_len: decompress_len - jump_len,
1214 },
1215 ..pending_seek
1216 });
1217 }
1218 PendingSeekState::RestoringSeekToFrame { .. }
1219 | PendingSeekState::RestoringJumpForward { .. } => {
1220 // The seek was cancelled, so poll until the
1221 // cancellation finished then return an error
1222 ready!(self.as_mut().poll_cancel_seek_tokio(cx))?;
1223 return std::task::Poll::Ready(Err(std::io::Error::other("seek cancelled")));
1224 }
1225 }
1226 }
1227 }
1228}
1229
#[cfg(feature = "futures")]
impl<R> futures::AsyncBufRead for AsyncZstdSeekableReader<'_, R>
where
    R: futures::AsyncBufRead + futures::AsyncSeek,
{
    /// Polls for buffered data from the wrapped reader.
    ///
    /// Any seek that is still in progress is cancelled first; the wrapped
    /// reader is only polled once that cancellation has finished.
    fn poll_fill_buf(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        // Unwind a half-finished seek (if any) before handing out data
        match self.as_mut().poll_cancel_seek_futures(cx) {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(Err(error)) => {
                return std::task::Poll::Ready(Err(error));
            }
            std::task::Poll::Ready(Ok(_)) => {}
        }

        // Nothing is pending anymore, so the wrapped reader does the real work
        self.project().reader.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the buffer as consumed.
    ///
    /// # Panics
    ///
    /// Panics if a seek is still pending.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let projected = self.project();

        // Data must not be removed from the buffer in the middle of a seek
        assert!(
            projected.pending_seek.is_none(),
            "tried to consume from buffer while seeking"
        );

        projected.reader.consume(amt);
    }
}
1261
#[cfg(feature = "futures")]
impl<R> futures::AsyncRead for AsyncZstdSeekableReader<'_, R>
where
    R: futures::AsyncBufRead + futures::AsyncSeek,
{
    /// Reads into `buf` by going through this reader's own
    /// `poll_fill_buf` / `consume` pair, returning the number of bytes
    /// copied.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use futures::AsyncBufRead as _;

        // Ask the buffered side of this reader for whatever is available
        let available = match self.as_mut().poll_fill_buf(cx) {
            std::task::Poll::Ready(Ok(available)) => available,
            std::task::Poll::Ready(Err(error)) => {
                return std::task::Poll::Ready(Err(error));
            }
            std::task::Poll::Pending => return std::task::Poll::Pending,
        };

        // Never copy more than the caller's buffer can hold
        let copy_len = std::cmp::min(available.len(), buf.len());
        let (source, _) = available.split_at(copy_len);
        buf[..copy_len].copy_from_slice(source);

        // Only the bytes actually handed to the caller are consumed
        self.consume(copy_len);

        std::task::Poll::Ready(Ok(copy_len))
    }
}
1289
#[cfg(feature = "futures")]
impl<R> futures::AsyncSeek for AsyncZstdSeekableReader<'_, R>
where
    R: futures::AsyncBufRead + futures::AsyncSeek,
{
    /// Seeks to `position` within the decompressed stream, returning the
    /// new decompressed position once it has been reached.
    ///
    /// The seek is driven as a state machine stored in `pending_seek`, so
    /// it can resume across calls that return `Poll::Pending`:
    ///
    /// - If a pending seek for the same `position` exists, it is resumed
    ///   from its saved state. Otherwise any in-progress seek is cancelled
    ///   (via `poll_cancel_seek_futures`) and a new one is started.
    /// - `SeekFrom::Start` and `SeekFrom::Current` resolve directly to a
    ///   target position. `SeekFrom::End` first seeks to the last known
    ///   frame (if the decoder reports one) and jumps to the end of the
    ///   stream, then resolves the target from there.
    /// - Once the target is known, the underlying reader is seeked to the
    ///   start of a frame if the decoder asks for that, and then data is
    ///   decompressed and discarded until the target is reached.
    ///
    /// # Errors
    ///
    /// - `"invalid seek offset"` if applying a relative offset overflows.
    /// - `ErrorKind::UnexpectedEof` if the stream ends before the target
    ///   position is reached.
    /// - `"seek cancelled"` if the pending seek was found in one of the
    ///   `Restoring*` states.
    /// - Any error reported by the underlying reader or the decoder.
    ///
    /// The error paths handled explicitly in this method clear
    /// `pending_seek` before returning.
    ///
    /// # Panics
    ///
    /// Panics if the reader's position does not equal the target position
    /// once there are no more bytes left to decompress.
    fn poll_seek(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        position: std::io::SeekFrom,
    ) -> std::task::Poll<std::io::Result<u64>> {
        use crate::buffer::Buffer as _;
        use futures::io::AsyncBufRead as _;

        // Iterate until the seek is complete. Each iteration should
        // make some progress based on the current pending seek state
        // (or should return `Poll::Pending`)
        loop {
            let this = self.as_mut().project();

            let pending_seek = match *this.pending_seek {
                Some(pending_seek) if pending_seek.seek == position => {
                    // We're already seeking with the same position,
                    // so keep going from where we left off
                    pending_seek
                }
                _ => {
                    // If there's another seek operation in progress,
                    // cancel it first
                    ready!(self.as_mut().poll_cancel_seek_futures(cx))?;

                    let this = self.as_mut().project();

                    // Start a new seek operation, remembering the position
                    // we started from
                    let pending_seek = PendingSeek {
                        initial_pos: this.reader.current_pos,
                        seek: position,
                        state: PendingSeekState::Starting,
                    };
                    *this.pending_seek = Some(pending_seek);
                    pending_seek
                }
            };

            // `pending_seek` is a copy of the stored state: every arm below
            // writes the next state back through `this.pending_seek`
            let mut this = self.as_mut().project();

            match pending_seek.state {
                PendingSeekState::Starting => {
                    // Seek is starting. The first step is to determine
                    // the seek target relative to the start of the stream

                    match pending_seek.seek {
                        std::io::SeekFrom::Start(offset) => {
                            // The offset is already relative to the start,
                            // so transition to the state to start seeking
                            *this.pending_seek = Some(PendingSeek {
                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
                                ..pending_seek
                            });
                        }
                        std::io::SeekFrom::End(end_offset) => {
                            // To determine the offset relative to the start,
                            // we first need to reach the end of the stream.

                            // Determine the best way to reach the last
                            // position we know about in the stream.
                            let seek = this.reader.decoder.prepare_seek_to_last_known_pos();

                            if let Some(frame) = seek.seek_to_frame_start {
                                // We have a frame to seek to, so
                                // transition to the state to seek before
                                // jumping to the end of the stream
                                *this.pending_seek = Some(PendingSeek {
                                    state: PendingSeekState::SeekingToLastFrame {
                                        frame,
                                        end_offset,
                                    },
                                    ..pending_seek
                                })
                            } else {
                                // No need to seek, so transition to the
                                // state to jump to the end of the stream
                                *this.pending_seek = Some(PendingSeek {
                                    state: PendingSeekState::JumpingToEnd { end_offset },
                                    ..pending_seek
                                });
                            }
                        }
                        std::io::SeekFrom::Current(offset) => {
                            // Compute the offset relative to the current position
                            let offset = this.reader.current_pos.checked_add_signed(offset);
                            let offset = match offset {
                                Some(offset) => offset,
                                None => {
                                    // Offset overflowed, so clear the pending
                                    // seek and return an error
                                    *this.pending_seek = None;
                                    return std::task::Poll::Ready(Err(std::io::Error::other(
                                        "invalid seek offset",
                                    )));
                                }
                            };

                            // Transition to the state to start seeking based
                            // on the computed offset
                            *this.pending_seek = Some(PendingSeek {
                                state: PendingSeekState::SeekingToTarget { target_pos: offset },
                                ..pending_seek
                            });
                        }
                    }
                }
                PendingSeekState::SeekingToLastFrame { end_offset, frame } => {
                    // We're seeking to the last (known) frame in the stream
                    // before jumping to the end

                    let reader = this.reader.as_mut().project();

                    // Seek the underlying reader
                    let result = ready!(reader
                        .reader
                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed,
                            // so clear the in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // The decoder failed, so clear the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    }

                    // Update our internal position to align with the
                    // start of the frame
                    *reader.current_pos = frame.decompressed_pos;

                    // Clear the buffer (this state is entered straight from
                    // `Starting`, so nothing has drained the buffer yet)
                    reader.buffer.clear();

                    // Seek complete, so transition states to jump to
                    // the end of the stream
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingToEnd { end_offset },
                        ..pending_seek
                    });
                }
                PendingSeekState::JumpingToEnd { end_offset } => {
                    // Seek target is relative to the end of the stream, so
                    // we need to jump to the end of the stream before
                    // determining the target position

                    // Try to jump to the end of stream
                    let result = ready!(this.reader.as_mut().poll_jump_to_end_futures(cx));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Jumping to the end failed, so cancel the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Now we're at the end of the stream, so we can now
                    // compute the target position
                    let target_pos = this.reader.current_pos.checked_add_signed(end_offset);
                    let target_pos = match target_pos {
                        Some(target_pos) => target_pos,
                        None => {
                            // Target position overflowed, so cancel the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(std::io::Error::other(
                                "invalid seek offset",
                            )));
                        }
                    };

                    // Transition to the state to start seeking based
                    // on the computed offset
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::SeekingToTarget { target_pos },
                        ..pending_seek
                    });
                }
                PendingSeekState::SeekingToTarget { target_pos } => {
                    // We now know the relative position we're seeking to

                    // Consume any leftover data in `self.buffer`. This ensures
                    // that the current position is in line with the
                    // underlying decoder
                    let consumable = this.reader.buffer.uncommitted().len();
                    this.reader.as_mut().consume(consumable);

                    // Determine what we need to do to reach the target position
                    let seek = this
                        .reader
                        .decoder
                        .prepare_seek_to_decompressed_pos(target_pos);

                    if let Some(frame) = seek.seek_to_frame_start {
                        // We need to seek to the start of a frame

                        // Transition to the state so we poll until
                        // the seek completes
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::SeekingToFrame {
                                target_pos,
                                frame,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    } else {
                        // We need to keep decoding bytes to reach the
                        // target position

                        // Transition to the state so that we can keep
                        // decompressing until reaching the target position
                        *this.pending_seek = Some(PendingSeek {
                            state: PendingSeekState::JumpingForward {
                                target_pos,
                                decompress_len: seek.decompress_len,
                            },
                            ..pending_seek
                        });
                    }
                }
                PendingSeekState::SeekingToFrame {
                    target_pos,
                    frame,
                    decompress_len,
                } => {
                    // We're seeking to the start of a frame

                    let reader = this.reader.as_mut().project();

                    // Seek the underlying reader
                    let result = ready!(reader
                        .reader
                        .poll_seek(cx, std::io::SeekFrom::Start(frame.compressed_pos)));
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // Seeking the underlying reader failed,
                            // so clear the in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    // Update the decoder based on what frame we're now at
                    let result = reader.decoder.seeked_to_frame(frame);
                    match result {
                        Ok(_) => {}
                        Err(error) => {
                            // The decoder failed, so clear the
                            // in-progress seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    }

                    // Update our internal position to align with the
                    // start of the frame. (Within this method, this state is
                    // only entered from `SeekingToTarget`, which already
                    // consumed any buffered data, so no buffer clear here.)
                    *reader.current_pos = frame.decompressed_pos;

                    // Seek complete, so transition states to decompress
                    // until reaching the target position
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingForward {
                            target_pos,
                            decompress_len,
                        },
                        ..pending_seek
                    });
                }
                PendingSeekState::JumpingForward {
                    target_pos,
                    decompress_len: 0,
                } => {
                    // No more bytes to decompress, so at the target position!
                    // Clear the pending seek, then we're done
                    assert_eq!(target_pos, this.reader.current_pos);
                    *this.pending_seek = None;
                    return std::task::Poll::Ready(Ok(this.reader.current_pos));
                }
                PendingSeekState::JumpingForward {
                    target_pos,
                    decompress_len,
                } => {
                    // We have some bytes to decompress before reaching
                    // the target position

                    // Try to decompress some data from the underlying
                    // reader
                    let result = ready!(this.reader.as_mut().poll_fill_buf(cx));
                    let filled = match result {
                        Ok(filled) => filled,
                        Err(error) => {
                            // Failed to decompress from the underlying
                            // reader, so clear the pending seek and bail
                            *this.pending_seek = None;
                            return std::task::Poll::Ready(Err(error));
                        }
                    };

                    if filled.is_empty() {
                        // The underlying reader didn't give us any data, which
                        // means we hit EOF while trying to seek. Clear the
                        // pending seek and bail
                        *this.pending_seek = None;
                        return std::task::Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "reached eof while trying to decode to offset",
                        )));
                    }

                    // Consume as much data as we can to try and reach
                    // the total data to decompress. `jump_len` is capped by
                    // `filled.len()`, so converting it back to `usize`
                    // cannot fail
                    let filled_len_u64: u64 = filled.len().try_into().unwrap();
                    let jump_len = filled_len_u64.min(decompress_len);
                    let jump_len_usize: usize = jump_len.try_into().unwrap();
                    this.reader.as_mut().consume(jump_len_usize);

                    // Update the state based on how much more data
                    // we have left to decompress
                    *this.pending_seek = Some(PendingSeek {
                        state: PendingSeekState::JumpingForward {
                            target_pos,
                            decompress_len: decompress_len - jump_len,
                        },
                        ..pending_seek
                    });
                }
                PendingSeekState::RestoringSeekToFrame { .. }
                | PendingSeekState::RestoringJumpForward { .. } => {
                    // The seek was cancelled, so poll until the
                    // cancellation finished then return an error
                    ready!(self.as_mut().poll_cancel_seek_futures(cx))?;
                    return std::task::Poll::Ready(Err(std::io::Error::other("seek cancelled")));
                }
            }
        }
    }
}
1647
/// A builder that builds an [`AsyncZstdReader`] from the provided reader.
///
/// The builder starts out with an empty seek table; use
/// [`with_seek_table`](ZstdReaderBuilder::with_seek_table) to supply one,
/// then call [`build`](ZstdReaderBuilder::build) to create the reader.
pub struct ZstdReaderBuilder<R> {
    /// The underlying reader, moved into the [`AsyncZstdReader`] on build.
    reader: R,
    /// The seek table passed to the decoder on build.
    table: ZstdSeekTable,
}
1653
#[cfg(feature = "tokio")]
impl<R> ZstdReaderBuilder<tokio::io::BufReader<R>> {
    /// Creates a builder around `reader`, wrapping it in a
    /// [`tokio::io::BufReader`] whose capacity is the input size reported
    /// by zstd's decompression context.
    fn new_tokio(reader: R) -> Self
    where
        R: tokio::io::AsyncRead,
    {
        let capacity = zstd::zstd_safe::DCtx::in_size();
        Self::with_buffered(tokio::io::BufReader::with_capacity(capacity, reader))
    }
}
1664
#[cfg(feature = "futures")]
impl<R> ZstdReaderBuilder<futures::io::BufReader<R>> {
    /// Creates a builder around `reader`, wrapping it in a
    /// [`futures::io::BufReader`] whose capacity is the input size reported
    /// by zstd's decompression context.
    fn new_futures(reader: R) -> Self
    where
        R: futures::AsyncRead,
    {
        let capacity = zstd::zstd_safe::DCtx::in_size();
        Self::with_buffered(futures::io::BufReader::with_capacity(capacity, reader))
    }
}
1676
1677impl<R> ZstdReaderBuilder<R> {
1678 fn with_buffered(reader: R) -> Self {
1679 ZstdReaderBuilder {
1680 reader,
1681 table: ZstdSeekTable::empty(),
1682 }
1683 }
1684
1685 /// Use the given seek table when seeking the resulting reader. This can
1686 /// greatly speed up seek operations when using a zstd stream that
1687 /// uses the [zstd seekable format].
1688 ///
1689 /// See [`crate::table::read_seek_table`] for reading a seek table.
1690 ///
1691 /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
1692 pub fn with_seek_table(mut self, table: ZstdSeekTable) -> Self {
1693 self.table = table;
1694 self
1695 }
1696
1697 /// Build the reader.
1698 pub fn build(self) -> std::io::Result<AsyncZstdReader<'static, R>> {
1699 let zstd_decoder = zstd::stream::raw::Decoder::new()?;
1700 let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::DCtx::out_size()]);
1701 let decoder = ZstdFramedDecoder::new(zstd_decoder, self.table);
1702
1703 Ok(AsyncZstdReader {
1704 reader: self.reader,
1705 decoder,
1706 buffer,
1707 current_pos: 0,
1708 })
1709 }
1710}
1711
/// An in-progress seek operation, stored between polls so that a seek
/// returning `Poll::Pending` can pick up where it left off.
#[cfg_attr(not(any(feature = "tokio", feature = "futures")), expect(dead_code))]
#[derive(Debug, Clone, Copy)]
struct PendingSeek {
    /// The reader's `current_pos` at the moment the seek was started.
    // NOTE(review): presumably used to restore the position when the seek
    // is cancelled -- confirm against `poll_cancel_seek_*`.
    initial_pos: u64,
    /// The position originally requested by the caller. A later seek call
    /// only resumes this pending seek if it asks for the same position.
    seek: std::io::SeekFrom,
    /// The step of the seek state machine to run next.
    state: PendingSeekState,
}
1719
/// The steps of the seek state machine. Each poll of a seek advances from
/// one state to the next until the target position is reached.
#[cfg_attr(not(any(feature = "tokio", feature = "futures")), expect(dead_code))]
#[derive(Debug, Clone, Copy)]
enum PendingSeekState {
    /// The seek was just created; the target position relative to the
    /// start of the stream has not been determined yet.
    Starting,
    /// Seeking the underlying reader to the start of the last known frame,
    /// as the first step of a seek relative to the end of the stream.
    SeekingToLastFrame {
        /// The requested offset relative to the end of the stream.
        end_offset: i64,
        /// The frame whose start the underlying reader is seeking to.
        frame: ZstdFrame,
    },
    /// Jumping to the end of the stream, so that the target position can be
    /// computed from the end position plus `end_offset`.
    JumpingToEnd {
        /// The requested offset relative to the end of the stream.
        end_offset: i64,
    },
    /// The target position is known; the next step is to ask the decoder
    /// how to reach it.
    SeekingToTarget {
        /// The target position in the decompressed stream.
        target_pos: u64,
    },
    /// Seeking the underlying reader to the start of a frame, after which
    /// `decompress_len` bytes still need to be decompressed.
    SeekingToFrame {
        /// The target position in the decompressed stream.
        target_pos: u64,
        /// The frame whose start the underlying reader is seeking to.
        frame: ZstdFrame,
        /// How many bytes to decompress after the frame seek completes.
        decompress_len: u64,
    },
    /// Decompressing and discarding data until the target position is
    /// reached. The seek is complete once `decompress_len` reaches 0.
    JumpingForward {
        /// The target position in the decompressed stream.
        target_pos: u64,
        /// How many more bytes need to be decompressed and discarded.
        decompress_len: u64,
    },
    /// The seek was cancelled; cancellation is still in progress.
    // NOTE(review): fields presumably describe the frame to seek back to
    // and the bytes to skip afterwards in order to restore the initial
    // position -- confirm against `poll_cancel_seek_*`.
    RestoringSeekToFrame {
        frame: ZstdFrame,
        decompress_len: u64,
    },
    /// The seek was cancelled; cancellation is still in progress.
    // NOTE(review): presumably the number of bytes left to skip forward to
    // get back to the initial position -- confirm against
    // `poll_cancel_seek_*`.
    RestoringJumpForward {
        decompress_len: u64,
    },
}