watermelon_net/connection/streaming.rs

1use std::{
2    future::{self, Future},
3    io,
4    pin::{pin, Pin},
5    task::{Context, Poll},
6};
7
8use bytes::Buf;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use watermelon_proto::proto::{
11    error::DecoderError, ClientOp, ServerOp, StreamDecoder, StreamEncoder,
12};
13
14#[derive(Debug)]
15pub struct StreamingConnection<S> {
16    socket: S,
17    encoder: StreamEncoder,
18    decoder: StreamDecoder,
19    may_flush: bool,
20}
21
22impl<S> StreamingConnection<S>
23where
24    S: AsyncRead + AsyncWrite + Unpin,
25{
26    #[must_use]
27    pub fn new(socket: S) -> Self {
28        Self {
29            socket,
30            encoder: StreamEncoder::new(),
31            decoder: StreamDecoder::new(),
32            may_flush: false,
33        }
34    }
35
36    pub fn poll_read_next(
37        &mut self,
38        cx: &mut Context<'_>,
39    ) -> Poll<Result<ServerOp, StreamingReadError>> {
40        loop {
41            match self.decoder.decode() {
42                Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)),
43                Ok(None) => {}
44                Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))),
45            }
46
47            let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf()));
48            match read_buf_fut.poll(cx) {
49                Poll::Pending => return Poll::Pending,
50                Poll::Ready(Ok(1..)) => {}
51                Poll::Ready(Ok(0)) => {
52                    return Poll::Ready(Err(StreamingReadError::Io(
53                        io::ErrorKind::UnexpectedEof.into(),
54                    )))
55                }
56                Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))),
57            }
58        }
59    }
60
61    /// Reads the next [`ServerOp`].
62    ///
63    /// # Errors
64    ///
65    /// It returns an error if the content cannot be decoded or if an I/O error occurs.
66    pub async fn read_next(&mut self) -> Result<ServerOp, StreamingReadError> {
67        future::poll_fn(|cx| self.poll_read_next(cx)).await
68    }
69
70    pub fn may_write(&self) -> bool {
71        self.encoder.has_remaining()
72    }
73
74    pub fn may_flush(&self) -> bool {
75        self.may_flush
76    }
77
78    pub fn may_enqueue_more_ops(&self) -> bool {
79        self.encoder.remaining() < 8_290_304
80    }
81
82    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
83        self.encoder.enqueue_write_op(item);
84    }
85
86    pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
87        let remaining = self.encoder.remaining();
88        if remaining == 0 {
89            return Poll::Ready(Ok(0));
90        }
91
92        let chunk = self.encoder.chunk();
93        let write_outcome = if chunk.len() < remaining && self.socket.is_write_vectored() {
94            let mut bufs = [io::IoSlice::new(&[]); 64];
95            let n = self.encoder.chunks_vectored(&mut bufs);
96            debug_assert!(n >= 2, "perf: chunks_vectored yielded less than 2 chunks despite the apparently fragmented internal encoder representation");
97
98            Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n])
99        } else {
100            debug_assert!(
101                !chunk.is_empty(),
102                "perf: chunk shouldn't be empty given that `remaining > 0`"
103            );
104            Pin::new(&mut self.socket).poll_write(cx, chunk)
105        };
106
107        match write_outcome {
108            Poll::Pending => {
109                self.may_flush = false;
110                Poll::Pending
111            }
112            Poll::Ready(Ok(n)) => {
113                self.encoder.advance(n);
114                self.may_flush = true;
115                Poll::Ready(Ok(n))
116            }
117            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
118        }
119    }
120
121    /// Writes the next chunk of data to the socket.
122    ///
123    /// It returns the number of bytes that have been written.
124    ///
125    /// # Errors
126    ///
127    /// An I/O error is returned if it is not possible to write to the socket.
128    pub async fn write_next(&mut self) -> io::Result<usize> {
129        future::poll_fn(|cx| self.poll_write_next(cx)).await
130    }
131
132    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133        match Pin::new(&mut self.socket).poll_flush(cx) {
134            Poll::Pending => Poll::Pending,
135            Poll::Ready(Ok(())) => {
136                self.may_flush = false;
137                Poll::Ready(Ok(()))
138            }
139            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
140        }
141    }
142
143    /// Flush any buffered writes to the connection
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if flushing fails
148    pub async fn flush(&mut self) -> io::Result<()> {
149        future::poll_fn(|cx| self.poll_flush(cx)).await
150    }
151
152    /// Shutdown the connection
153    ///
154    /// # Errors
155    ///
156    /// Returns an error if shutting down the connection fails.
157    /// Implementations usually ignore this error.
158    pub async fn shutdown(&mut self) -> io::Result<()> {
159        future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await
160    }
161
162    pub fn socket(&self) -> &S {
163        &self.socket
164    }
165
166    pub fn socket_mut(&mut self) -> &mut S {
167        &mut self.socket
168    }
169
170    pub fn replace_socket<F, S2>(self, replacer: F) -> StreamingConnection<S2>
171    where
172        F: FnOnce(S) -> S2,
173    {
174        StreamingConnection {
175            socket: replacer(self.socket),
176            encoder: self.encoder,
177            decoder: self.decoder,
178            may_flush: self.may_flush,
179        }
180    }
181
182    pub fn into_inner(self) -> S {
183        self.socket
184    }
185}
186
187#[derive(Debug, thiserror::Error)]
188pub enum StreamingReadError {
189    #[error("decoder")]
190    Decoder(#[source] DecoderError),
191    #[error("io")]
192    Io(#[source] io::Error),
193}
194
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use claims::assert_matches;
    use futures_util::task;
    use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
    use watermelon_proto::proto::{ClientOp, ServerOp};

    use super::{StreamingConnection, StreamingReadError};

    #[test]
    fn ping_pong() {
        let waker = task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (socket, mut conn) = io::duplex(1024);

        let mut client = StreamingConnection::new(socket);

        // Initial state is ok
        assert!(client.poll_read_next(&mut cx).is_pending());
        assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(0)));

        let mut buf = [0; 1024];
        let mut read_buf = ReadBuf::new(&mut buf);
        assert!(Pin::new(&mut conn)
            .poll_read(&mut cx, &mut read_buf)
            .is_pending());

        // Write PING and verify it was received
        client.enqueue_write_op(&ClientOp::Ping);
        assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(6)));
        assert_matches!(
            Pin::new(&mut conn).poll_read(&mut cx, &mut read_buf),
            Poll::Ready(Ok(()))
        );
        assert_eq!(read_buf.filled(), b"PING\r\n");

        // Receive PONG
        assert_matches!(
            Pin::new(&mut conn).poll_write(&mut cx, b"PONG\r\n"),
            Poll::Ready(Ok(6))
        );
        assert_matches!(
            client.poll_read_next(&mut cx),
            Poll::Ready(Ok(ServerOp::Pong))
        );
        assert!(client.poll_read_next(&mut cx).is_pending());
    }

    #[test]
    fn peer_close_is_reported_as_unexpected_eof() {
        let waker = task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (socket, mut conn) = io::duplex(1024);
        let mut client = StreamingConnection::new(socket);

        assert_matches!(
            Pin::new(&mut conn).poll_write(&mut cx, b"PONG\r\n"),
            Poll::Ready(Ok(6))
        );
        // Closing the peer half makes further reads on `client` hit EOF once
        // the already-buffered bytes are consumed.
        drop(conn);

        // Data sent before the close is still delivered...
        assert_matches!(
            client.poll_read_next(&mut cx),
            Poll::Ready(Ok(ServerOp::Pong))
        );

        // ...and only then is the end of stream surfaced as an error.
        let Poll::Ready(Err(StreamingReadError::Io(err))) = client.poll_read_next(&mut cx) else {
            panic!("expected an I/O error after the peer closed the stream");
        };
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}