1use crate::header::Header;
2use std::cmp;
3use std::io::{Error, ErrorKind, SeekFrom};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, ReadBuf};
7
8pub struct Entry<'a, R: 'a + AsyncRead + Unpin> {
14 pub(crate) header: &'a Header,
15 pub(crate) reader: &'a mut R,
16 pub(crate) length: u64,
17 pub(crate) position: u64,
18 pub(crate) unread_counter: &'a mut u64,
19}
20
21impl<'a, R: 'a + AsyncRead + Unpin> Entry<'a, R> {
22 pub fn header(&self) -> &Header {
24 self.header
25 }
26}
27
28impl<'a, R: 'a + AsyncRead + Unpin> AsyncRead for Entry<'a, R> {
29 fn poll_read(
30 mut self: Pin<&mut Self>,
31 cx: &mut Context<'_>,
32 buf: &mut ReadBuf<'_>,
33 ) -> Poll<std::io::Result<()>> {
34 debug_assert!(self.position <= self.length);
35 let remaining = self.length.saturating_sub(self.position);
36
37 if remaining == 0 {
38 return Poll::Ready(Ok(()));
39 }
40
41 let max_len = cmp::min(remaining, buf.remaining() as u64);
42
43 let filled_before = buf.filled().len() as u64;
45
46 match Pin::new(&mut self.reader.take(max_len)).poll_read(cx, buf) {
47 Poll::Ready(Ok(())) => {
48 let filled_after = buf.filled().len() as u64;
50 let bytes_read = filled_after - filled_before;
51
52 self.position += bytes_read;
54 *self.unread_counter -= bytes_read;
55 debug_assert!(self.position <= self.length);
56 Poll::Ready(Ok(()))
57 }
58 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
59 Poll::Pending => Poll::Pending,
60 }
61 }
62}
63
64impl<'a, R: 'a + AsyncRead + AsyncSeek + Unpin> AsyncSeek for Entry<'a, R> {
65 fn start_seek(
66 mut self: Pin<&mut Self>,
67 pos: SeekFrom,
68 ) -> std::io::Result<()> {
69 let delta = match pos {
70 SeekFrom::Start(offset) => offset as i64 - self.position as i64,
71 SeekFrom::End(offset) => {
72 self.length as i64 + offset - self.position as i64
73 }
74 SeekFrom::Current(delta) => delta,
75 };
76 let new_position = self.position as i64 + delta;
77 if new_position < 0 {
78 let msg = format!(
79 "Invalid seek to negative position ({})",
80 new_position
81 );
82 return Err(Error::new(ErrorKind::InvalidInput, msg));
83 }
84 let new_position = new_position as u64;
85 if new_position > self.length {
86 let msg = format!(
87 "Invalid seek to position past end of entry ({} vs. {})",
88 new_position, self.length
89 );
90 return Err(Error::new(ErrorKind::InvalidInput, msg));
91 }
92 Pin::new(&mut self.reader).start_seek(SeekFrom::Current(delta))?;
93 self.position = new_position;
94 *self.unread_counter = self.length - self.position;
95 Ok(())
96 }
97
98 fn poll_complete(
99 mut self: Pin<&mut Self>,
100 cx: &mut Context<'_>,
101 ) -> Poll<std::io::Result<u64>> {
102 match Pin::new(&mut self.reader).poll_complete(cx) {
103 Poll::Ready(result) => Poll::Ready(result.map(|_| self.position)),
104 Poll::Pending => Poll::Pending,
105 }
106 }
107}