tokio_ar/
entry.rs

1use crate::header::Header;
2use std::cmp;
3use std::io::{Error, ErrorKind, SeekFrom};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, ReadBuf};
7
8/// Representation of an archive entry.
9///
10/// `Entry` objects implement the `AsyncRead` trait, and can be used to extract the
11/// data from this archive entry.  If the underlying reader supports the `AsyncSeek`
12/// trait, then the `Entry` object supports `AsyncSeek` as well.
13pub struct Entry<'a, R: 'a + AsyncRead + Unpin> {
14    pub(crate) header: &'a Header,
15    pub(crate) reader: &'a mut R,
16    pub(crate) length: u64,
17    pub(crate) position: u64,
18    pub(crate) unread_counter: &'a mut u64,
19}
20
21impl<'a, R: 'a + AsyncRead + Unpin> Entry<'a, R> {
22    /// Returns the header for this archive entry.
23    pub fn header(&self) -> &Header {
24        self.header
25    }
26}
27
28impl<'a, R: 'a + AsyncRead + Unpin> AsyncRead for Entry<'a, R> {
29    fn poll_read(
30        mut self: Pin<&mut Self>,
31        cx: &mut Context<'_>,
32        buf: &mut ReadBuf<'_>,
33    ) -> Poll<std::io::Result<()>> {
34        debug_assert!(self.position <= self.length);
35        let remaining = self.length.saturating_sub(self.position);
36
37        if remaining == 0 {
38            return Poll::Ready(Ok(()));
39        }
40
41        let max_len = cmp::min(remaining, buf.remaining() as u64);
42
43        // Remember the initial filled length
44        let filled_before = buf.filled().len() as u64;
45
46        match Pin::new(&mut self.reader.take(max_len)).poll_read(cx, buf) {
47            Poll::Ready(Ok(())) => {
48                // Calculate how many bytes were read
49                let filled_after = buf.filled().len() as u64;
50                let bytes_read = filled_after - filled_before;
51
52                // Update position and unread counter
53                self.position += bytes_read;
54                *self.unread_counter -= bytes_read;
55                debug_assert!(self.position <= self.length);
56                Poll::Ready(Ok(()))
57            }
58            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
59            Poll::Pending => Poll::Pending,
60        }
61    }
62}
63
64impl<'a, R: 'a + AsyncRead + AsyncSeek + Unpin> AsyncSeek for Entry<'a, R> {
65    fn start_seek(
66        mut self: Pin<&mut Self>,
67        pos: SeekFrom,
68    ) -> std::io::Result<()> {
69        let delta = match pos {
70            SeekFrom::Start(offset) => offset as i64 - self.position as i64,
71            SeekFrom::End(offset) => {
72                self.length as i64 + offset - self.position as i64
73            }
74            SeekFrom::Current(delta) => delta,
75        };
76        let new_position = self.position as i64 + delta;
77        if new_position < 0 {
78            let msg = format!(
79                "Invalid seek to negative position ({})",
80                new_position
81            );
82            return Err(Error::new(ErrorKind::InvalidInput, msg));
83        }
84        let new_position = new_position as u64;
85        if new_position > self.length {
86            let msg = format!(
87                "Invalid seek to position past end of entry ({} vs. {})",
88                new_position, self.length
89            );
90            return Err(Error::new(ErrorKind::InvalidInput, msg));
91        }
92        Pin::new(&mut self.reader).start_seek(SeekFrom::Current(delta))?;
93        self.position = new_position;
94        *self.unread_counter = self.length - self.position;
95        Ok(())
96    }
97
98    fn poll_complete(
99        mut self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101    ) -> Poll<std::io::Result<u64>> {
102        match Pin::new(&mut self.reader).poll_complete(cx) {
103            Poll::Ready(result) => Poll::Ready(result.map(|_| self.position)),
104            Poll::Pending => Poll::Pending,
105        }
106    }
107}