watermelon_net/connection/
streaming.rs1use std::{
2 future::{self, Future},
3 io,
4 pin::{pin, Pin},
5 task::{Context, Poll},
6};
7
8use bytes::Buf;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use watermelon_proto::proto::{
11 error::DecoderError, ClientOp, ServerOp, StreamDecoder, StreamEncoder,
12};
13
14#[derive(Debug)]
15pub struct StreamingConnection<S> {
16 socket: S,
17 encoder: StreamEncoder,
18 decoder: StreamDecoder,
19 may_flush: bool,
20}
21
22impl<S> StreamingConnection<S>
23where
24 S: AsyncRead + AsyncWrite + Unpin,
25{
26 #[must_use]
27 pub fn new(socket: S) -> Self {
28 Self {
29 socket,
30 encoder: StreamEncoder::new(),
31 decoder: StreamDecoder::new(),
32 may_flush: false,
33 }
34 }
35
36 pub fn poll_read_next(
37 &mut self,
38 cx: &mut Context<'_>,
39 ) -> Poll<Result<ServerOp, StreamingReadError>> {
40 loop {
41 match self.decoder.decode() {
42 Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)),
43 Ok(None) => {}
44 Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))),
45 }
46
47 let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf()));
48 match read_buf_fut.poll(cx) {
49 Poll::Pending => return Poll::Pending,
50 Poll::Ready(Ok(1..)) => {}
51 Poll::Ready(Ok(0)) => {
52 return Poll::Ready(Err(StreamingReadError::Io(
53 io::ErrorKind::UnexpectedEof.into(),
54 )))
55 }
56 Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))),
57 }
58 }
59 }
60
61 pub async fn read_next(&mut self) -> Result<ServerOp, StreamingReadError> {
67 future::poll_fn(|cx| self.poll_read_next(cx)).await
68 }
69
70 pub fn may_write(&self) -> bool {
71 self.encoder.has_remaining()
72 }
73
74 pub fn may_flush(&self) -> bool {
75 self.may_flush
76 }
77
78 pub fn may_enqueue_more_ops(&self) -> bool {
79 self.encoder.remaining() < 8_290_304
80 }
81
82 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
83 self.encoder.enqueue_write_op(item);
84 }
85
86 pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
87 let remaining = self.encoder.remaining();
88 if remaining == 0 {
89 return Poll::Ready(Ok(0));
90 }
91
92 let chunk = self.encoder.chunk();
93 let write_outcome = if chunk.len() < remaining && self.socket.is_write_vectored() {
94 let mut bufs = [io::IoSlice::new(&[]); 64];
95 let n = self.encoder.chunks_vectored(&mut bufs);
96 debug_assert!(n >= 2, "perf: chunks_vectored yielded less than 2 chunks despite the apparently fragmented internal encoder representation");
97
98 Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n])
99 } else {
100 debug_assert!(
101 !chunk.is_empty(),
102 "perf: chunk shouldn't be empty given that `remaining > 0`"
103 );
104 Pin::new(&mut self.socket).poll_write(cx, chunk)
105 };
106
107 match write_outcome {
108 Poll::Pending => {
109 self.may_flush = false;
110 Poll::Pending
111 }
112 Poll::Ready(Ok(n)) => {
113 self.encoder.advance(n);
114 self.may_flush = true;
115 Poll::Ready(Ok(n))
116 }
117 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
118 }
119 }
120
121 pub async fn write_next(&mut self) -> io::Result<usize> {
129 future::poll_fn(|cx| self.poll_write_next(cx)).await
130 }
131
132 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133 match Pin::new(&mut self.socket).poll_flush(cx) {
134 Poll::Pending => Poll::Pending,
135 Poll::Ready(Ok(())) => {
136 self.may_flush = false;
137 Poll::Ready(Ok(()))
138 }
139 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
140 }
141 }
142
143 pub async fn flush(&mut self) -> io::Result<()> {
149 future::poll_fn(|cx| self.poll_flush(cx)).await
150 }
151
152 pub async fn shutdown(&mut self) -> io::Result<()> {
159 future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await
160 }
161
162 pub fn socket(&self) -> &S {
163 &self.socket
164 }
165
166 pub fn socket_mut(&mut self) -> &mut S {
167 &mut self.socket
168 }
169
170 pub fn replace_socket<F, S2>(self, replacer: F) -> StreamingConnection<S2>
171 where
172 F: FnOnce(S) -> S2,
173 {
174 StreamingConnection {
175 socket: replacer(self.socket),
176 encoder: self.encoder,
177 decoder: self.decoder,
178 may_flush: self.may_flush,
179 }
180 }
181
182 pub fn into_inner(self) -> S {
183 self.socket
184 }
185}
186
187#[derive(Debug, thiserror::Error)]
188pub enum StreamingReadError {
189 #[error("decoder")]
190 Decoder(#[source] DecoderError),
191 #[error("io")]
192 Io(#[source] io::Error),
193}
194
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use claims::assert_matches;
    use futures_util::task;
    use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
    use watermelon_proto::proto::{ClientOp, ServerOp};

    use super::StreamingConnection;

    /// Drives a `PING` / `PONG` exchange over a 1024-byte `tokio::io::duplex` pair.
    /// Every step is polled by hand with a no-op waker, so each one must be observed as
    /// either `Ready` or `Pending`.
    #[test]
    fn ping_pong() {
        let noop = task::noop_waker();
        let mut ctx = Context::from_waker(&noop);

        let (client_side, mut server_side) = io::duplex(1024);
        let mut client = StreamingConnection::new(client_side);

        // Nothing has been received or queued yet.
        assert!(client.poll_read_next(&mut ctx).is_pending());
        assert_matches!(client.poll_write_next(&mut ctx), Poll::Ready(Ok(0)));

        let mut storage = [0; 1024];
        let mut received = ReadBuf::new(&mut storage);
        assert!(Pin::new(&mut server_side)
            .poll_read(&mut ctx, &mut received)
            .is_pending());

        // Client -> server: PING.
        client.enqueue_write_op(&ClientOp::Ping);
        assert_matches!(client.poll_write_next(&mut ctx), Poll::Ready(Ok(6)));
        assert_matches!(
            Pin::new(&mut server_side).poll_read(&mut ctx, &mut received),
            Poll::Ready(Ok(()))
        );
        assert_eq!(received.filled(), b"PING\r\n");

        // Server -> client: PONG, then nothing further is pending.
        assert_matches!(
            Pin::new(&mut server_side).poll_write(&mut ctx, b"PONG\r\n"),
            Poll::Ready(Ok(6))
        );
        assert_matches!(
            client.poll_read_next(&mut ctx),
            Poll::Ready(Ok(ServerOp::Pong))
        );
        assert!(client.poll_read_next(&mut ctx).is_pending());
    }
}