1use std::{
2 io,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::{Buf, BytesMut};
9use futures_core::{ready, Stream};
10use ordered_varint::Variable;
11use tokio::io::{AsyncRead, ReadBuf};
12use transmog::OwnedDeserializer;
13
/// Reads length-prefixed messages from `reader` and deserializes each one
/// into a `T` using `format`.
///
/// Values are produced through the [`Stream`] implementation when `R` is an
/// [`AsyncRead`] and `F` is an [`OwnedDeserializer<T>`].
#[derive(Debug)]
pub struct TransmogReader<R, T, F> {
    /// Format used to deserialize each message payload.
    format: F,
    /// Source the raw bytes are read from.
    reader: R,
    /// Bytes that have been read from `reader` but not yet consumed as a
    /// complete message.
    pub(crate) buffer: BytesMut,
    /// Ties the reader to the value type `T` without storing one.
    into: PhantomData<T>,
}
29
// The reader is `Unpin` whenever the wrapped reader is, regardless of `T` and `F`.
impl<R, T, F> Unpin for TransmogReader<R, T, F> where R: Unpin {}
31
32impl<R, T, F> TransmogReader<R, T, F> {
33 pub fn get_ref(&self) -> &R {
37 &self.reader
38 }
39
40 pub fn get_mut(&mut self) -> &mut R {
44 &mut self.reader
45 }
46
47 pub fn buffer(&self) -> &[u8] {
51 &self.buffer[..]
52 }
53
54 pub fn into_inner(self) -> R {
58 self.reader
59 }
60}
61
62impl<R, T, F> TransmogReader<R, T, F> {
63 pub fn new(reader: R, format: F) -> Self {
65 TransmogReader {
66 format,
67 buffer: BytesMut::with_capacity(8192),
68 reader,
69 into: PhantomData,
70 }
71 }
72
73 pub fn default_for(format: F) -> Self
75 where
76 R: Default,
77 {
78 Self::new(R::default(), format)
79 }
80}
81
82impl<R, T, F> Stream for TransmogReader<R, T, F>
83where
84 R: AsyncRead + Unpin,
85 F: OwnedDeserializer<T>,
86{
87 type Item = Result<T, F::Error>;
88 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89 loop {
90 let fill_result = ready!(self
91 .as_mut()
92 .fill(cx, 9)
93 .map_err(<F::Error as From<std::io::Error>>::from))?;
94
95 let mut buf_reader = &self.buffer[..];
96 let buffer_start = buf_reader.as_ptr() as usize;
97 if let Ok(message_size) = u64::decode_variable(&mut buf_reader) {
98 let header_len = buf_reader.as_ptr() as usize - buffer_start;
99 let target_buffer_size = usize::try_from(message_size).unwrap() + header_len;
100
101 ready!(self
102 .as_mut()
103 .fill(cx, target_buffer_size)
104 .map_err(<F::Error as From<std::io::Error>>::from))?;
105
106 if self.buffer.len() >= target_buffer_size {
107 let message = self
108 .format
109 .deserialize_owned(&self.buffer[header_len..target_buffer_size])
110 .unwrap();
111 self.buffer.advance(target_buffer_size);
112 break Poll::Ready(Some(Ok(message)));
113 }
114 } else if let ReadResult::Eof = fill_result {
115 break Poll::Ready(None);
116 }
117 }
118 }
119}
120
/// Outcome of a `fill` call that completed without an I/O error.
#[derive(Debug)]
enum ReadResult {
    /// The buffer already held the requested number of bytes, or the read
    /// produced at least one new byte.
    ReceivedData,
    /// The read produced zero bytes.
    Eof,
}
126
127impl<R, T, F> TransmogReader<R, T, F>
128where
129 R: AsyncRead + Unpin,
130{
131 fn fill(
132 mut self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 target_size: usize,
135 ) -> Poll<Result<ReadResult, io::Error>> {
136 if self.buffer.len() >= target_size {
137 return Poll::Ready(Ok(ReadResult::ReceivedData));
139 }
140
141 if self.buffer.capacity() < target_size {
144 let missing = target_size - self.buffer.capacity();
145 self.buffer.reserve(missing);
146 }
147
148 let had = self.buffer.len();
149 let mut rest = self.buffer.split_off(had);
151 let max = rest.capacity();
154 rest.resize(max, 0);
157
158 let mut buf = ReadBuf::new(&mut rest[..]);
159 ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
160 let n = buf.filled().len();
161 let read = rest.split_to(n);
163 self.buffer.unsplit(read);
164 if n == 0 {
165 return Poll::Ready(Ok(ReadResult::Eof));
166 }
167
168 Poll::Ready(Ok(ReadResult::ReceivedData))
169 }
170}