transmog_async/
reader.rs

1use std::{
2    io,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::{Buf, BytesMut};
9use futures_core::{ready, Stream};
10use ordered_varint::Variable;
11use tokio::io::{AsyncRead, ReadBuf};
12use transmog::OwnedDeserializer;
13
14/// A wrapper around an asynchronous reader that produces an asynchronous stream
15/// of Transmog-decoded values.
16///
17/// To use, provide a reader that implements [`AsyncRead`], and then use
18/// [`Stream`] to access the deserialized values.
19///
20/// Note that the sender *must* prefix each serialized item with its size
21/// encoded using [`ordered-varint`](ordered_varint).
22#[derive(Debug)]
23pub struct TransmogReader<R, T, F> {
24    format: F,
25    reader: R,
26    pub(crate) buffer: BytesMut,
27    into: PhantomData<T>,
28}
29
30impl<R, T, F> Unpin for TransmogReader<R, T, F> where R: Unpin {}
31
32impl<R, T, F> TransmogReader<R, T, F> {
33    /// Gets a reference to the underlying reader.
34    ///
35    /// It is inadvisable to directly read from the underlying reader.
36    pub fn get_ref(&self) -> &R {
37        &self.reader
38    }
39
40    /// Gets a mutable reference to the underlying reader.
41    ///
42    /// It is inadvisable to directly read from the underlying reader.
43    pub fn get_mut(&mut self) -> &mut R {
44        &mut self.reader
45    }
46
47    /// Returns a reference to the internally buffered data.
48    ///
49    /// This will not attempt to fill the buffer if it is empty.
50    pub fn buffer(&self) -> &[u8] {
51        &self.buffer[..]
52    }
53
54    /// Unwraps this `TransmogReader`, returning the underlying reader.
55    ///
56    /// Note that any leftover data in the internal buffer is lost.
57    pub fn into_inner(self) -> R {
58        self.reader
59    }
60}
61
62impl<R, T, F> TransmogReader<R, T, F> {
63    /// Returns a new instance that reads `format`-encoded data for `reader`.
64    pub fn new(reader: R, format: F) -> Self {
65        TransmogReader {
66            format,
67            buffer: BytesMut::with_capacity(8192),
68            reader,
69            into: PhantomData,
70        }
71    }
72
73    /// Returns a new instance that reads `format`-encoded data for `R::default()`.
74    pub fn default_for(format: F) -> Self
75    where
76        R: Default,
77    {
78        Self::new(R::default(), format)
79    }
80}
81
82impl<R, T, F> Stream for TransmogReader<R, T, F>
83where
84    R: AsyncRead + Unpin,
85    F: OwnedDeserializer<T>,
86{
87    type Item = Result<T, F::Error>;
88    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89        loop {
90            let fill_result = ready!(self
91                .as_mut()
92                .fill(cx, 9)
93                .map_err(<F::Error as From<std::io::Error>>::from))?;
94
95            let mut buf_reader = &self.buffer[..];
96            let buffer_start = buf_reader.as_ptr() as usize;
97            if let Ok(message_size) = u64::decode_variable(&mut buf_reader) {
98                let header_len = buf_reader.as_ptr() as usize - buffer_start;
99                let target_buffer_size = usize::try_from(message_size).unwrap() + header_len;
100
101                ready!(self
102                    .as_mut()
103                    .fill(cx, target_buffer_size)
104                    .map_err(<F::Error as From<std::io::Error>>::from))?;
105
106                if self.buffer.len() >= target_buffer_size {
107                    let message = self
108                        .format
109                        .deserialize_owned(&self.buffer[header_len..target_buffer_size])
110                        .unwrap();
111                    self.buffer.advance(target_buffer_size);
112                    break Poll::Ready(Some(Ok(message)));
113                }
114            } else if let ReadResult::Eof = fill_result {
115                break Poll::Ready(None);
116            }
117        }
118    }
119}
120
/// Outcome of one attempt to top up the internal buffer.
#[derive(Debug)]
enum ReadResult {
    /// The buffer already held enough bytes, or the read delivered at least
    /// one new byte.
    ReceivedData,
    /// The read delivered zero bytes.
    Eof,
}
126
127impl<R, T, F> TransmogReader<R, T, F>
128where
129    R: AsyncRead + Unpin,
130{
131    fn fill(
132        mut self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        target_size: usize,
135    ) -> Poll<Result<ReadResult, io::Error>> {
136        if self.buffer.len() >= target_size {
137            // we already have the bytes we need!
138            return Poll::Ready(Ok(ReadResult::ReceivedData));
139        }
140
141        // make sure we can fit all the data we're about to read
142        // and then some, so we don't do a gazillion syscalls
143        if self.buffer.capacity() < target_size {
144            let missing = target_size - self.buffer.capacity();
145            self.buffer.reserve(missing);
146        }
147
148        let had = self.buffer.len();
149        // this is the bit we'll be reading into
150        let mut rest = self.buffer.split_off(had);
151        // this is safe because we're not extending beyond the reserved capacity
152        // and we're never reading unwritten bytes
153        let max = rest.capacity();
154        // In the original implementation, this was an unsafe operation.
155        // unsafe { rest.set_len(max) };
156        rest.resize(max, 0);
157
158        let mut buf = ReadBuf::new(&mut rest[..]);
159        ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
160        let n = buf.filled().len();
161        // adopt the new bytes
162        let read = rest.split_to(n);
163        self.buffer.unsplit(read);
164        if n == 0 {
165            return Poll::Ready(Ok(ReadResult::Eof));
166        }
167
168        Poll::Ready(Ok(ReadResult::ReceivedData))
169    }
170}