watermelon_net/connection/
streaming.rs1use std::{
2 future::{self, Future},
3 io,
4 pin::{Pin, pin},
5 task::{Context, Poll},
6};
7
8use bytes::Buf;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use watermelon_proto::proto::{
11 ClientOp, ServerOp, StreamDecoder, StreamEncoder, error::DecoderError,
12};
13
/// A connection that encodes [`ClientOp`]s into, and decodes [`ServerOp`]s out
/// of, an `AsyncRead + AsyncWrite` byte-stream socket.
#[derive(Debug)]
pub struct StreamingConnection<S> {
    /// The underlying byte stream.
    socket: S,
    /// Holds encoded client operations until they have been written to `socket`.
    encoder: StreamEncoder,
    /// Holds bytes read from `socket` until a complete server operation can be
    /// decoded from them.
    decoder: StreamDecoder,
    /// Set after a write to `socket` succeeds; cleared by a successful flush or
    /// when a write attempt returns `Pending`.
    may_flush: bool,
}
21
22impl<S> StreamingConnection<S>
23where
24 S: AsyncRead + AsyncWrite + Unpin,
25{
26 #[must_use]
27 pub fn new(socket: S) -> Self {
28 Self {
29 socket,
30 encoder: StreamEncoder::new(),
31 decoder: StreamDecoder::new(),
32 may_flush: false,
33 }
34 }
35
36 pub fn poll_read_next(
37 &mut self,
38 cx: &mut Context<'_>,
39 ) -> Poll<Result<ServerOp, StreamingReadError>> {
40 loop {
41 match self.decoder.decode() {
42 Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)),
43 Ok(None) => {}
44 Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))),
45 }
46
47 let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf()));
48 match read_buf_fut.poll(cx) {
49 Poll::Pending => return Poll::Pending,
50 Poll::Ready(Ok(1..)) => {}
51 Poll::Ready(Ok(0)) => {
52 return Poll::Ready(Err(StreamingReadError::Io(
53 io::ErrorKind::UnexpectedEof.into(),
54 )));
55 }
56 Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))),
57 }
58 }
59 }
60
61 pub async fn read_next(&mut self) -> Result<ServerOp, StreamingReadError> {
67 future::poll_fn(|cx| self.poll_read_next(cx)).await
68 }
69
70 pub fn may_write(&self) -> bool {
71 self.encoder.has_remaining()
72 }
73
74 pub fn may_flush(&self) -> bool {
75 self.may_flush
76 }
77
78 pub fn may_enqueue_more_ops(&self) -> bool {
79 self.encoder.remaining() < 8_290_304
80 }
81
82 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
83 self.encoder.enqueue_write_op(item);
84 }
85
86 pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
87 let remaining = self.encoder.remaining();
88 if remaining == 0 {
89 return Poll::Ready(Ok(0));
90 }
91
92 let chunk = self.encoder.chunk();
93 let write_outcome = if chunk.len() < remaining && self.socket.is_write_vectored() {
94 let mut bufs = [io::IoSlice::new(&[]); 64];
95 let n = self.encoder.chunks_vectored(&mut bufs);
96 debug_assert!(
97 n >= 2,
98 "perf: chunks_vectored yielded less than 2 chunks despite the apparently fragmented internal encoder representation"
99 );
100
101 Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n])
102 } else {
103 debug_assert!(
104 !chunk.is_empty(),
105 "perf: chunk shouldn't be empty given that `remaining > 0`"
106 );
107 Pin::new(&mut self.socket).poll_write(cx, chunk)
108 };
109
110 match write_outcome {
111 Poll::Pending => {
112 self.may_flush = false;
113 Poll::Pending
114 }
115 Poll::Ready(Ok(n)) => {
116 self.encoder.advance(n);
117 self.may_flush = true;
118 Poll::Ready(Ok(n))
119 }
120 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
121 }
122 }
123
124 pub async fn write_next(&mut self) -> io::Result<usize> {
132 future::poll_fn(|cx| self.poll_write_next(cx)).await
133 }
134
135 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136 match Pin::new(&mut self.socket).poll_flush(cx) {
137 Poll::Pending => Poll::Pending,
138 Poll::Ready(Ok(())) => {
139 self.may_flush = false;
140 Poll::Ready(Ok(()))
141 }
142 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
143 }
144 }
145
146 pub async fn flush(&mut self) -> io::Result<()> {
152 future::poll_fn(|cx| self.poll_flush(cx)).await
153 }
154
155 pub async fn shutdown(&mut self) -> io::Result<()> {
162 future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await
163 }
164
165 pub fn socket(&self) -> &S {
166 &self.socket
167 }
168
169 pub fn socket_mut(&mut self) -> &mut S {
170 &mut self.socket
171 }
172
173 pub fn replace_socket<F, S2>(self, replacer: F) -> StreamingConnection<S2>
174 where
175 F: FnOnce(S) -> S2,
176 {
177 StreamingConnection {
178 socket: replacer(self.socket),
179 encoder: self.encoder,
180 decoder: self.decoder,
181 may_flush: self.may_flush,
182 }
183 }
184
185 pub fn into_inner(self) -> S {
186 self.socket
187 }
188}
189
/// Error returned when reading the next `ServerOp` from a `StreamingConnection`.
#[derive(Debug, thiserror::Error)]
pub enum StreamingReadError {
    /// The bytes received from the socket could not be decoded into a server
    /// operation.
    #[error("decoder")]
    Decoder(#[source] DecoderError),
    /// Reading from the socket failed. A socket EOF (a read of zero bytes) is
    /// also reported through this variant, as `io::ErrorKind::UnexpectedEof`.
    #[error("io")]
    Io(#[source] io::Error),
}
197
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use claims::assert_matches;
    use futures_util::task;
    use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
    use watermelon_proto::proto::{ClientOp, ServerOp};

    use super::StreamingConnection;

    /// Drives a `PING` / `PONG` round trip over an in-memory duplex pipe,
    /// polling every step by hand with a no-op waker.
    #[test]
    fn ping_pong() {
        let noop = task::noop_waker();
        let mut ctx = Context::from_waker(&noop);

        let (client_half, mut peer) = io::duplex(1024);
        let mut client = StreamingConnection::new(client_half);

        // Idle: nothing to read on the client, nothing queued to write.
        assert!(client.poll_read_next(&mut ctx).is_pending());
        assert_matches!(client.poll_write_next(&mut ctx), Poll::Ready(Ok(0)));

        // Idle: the peer has nothing to read either.
        let mut storage = [0; 1024];
        let mut received = ReadBuf::new(&mut storage);
        assert!(
            Pin::new(&mut peer)
                .poll_read(&mut ctx, &mut received)
                .is_pending()
        );

        // Client -> peer: PING.
        client.enqueue_write_op(&ClientOp::Ping);
        assert_matches!(client.poll_write_next(&mut ctx), Poll::Ready(Ok(6)));
        assert_matches!(
            Pin::new(&mut peer).poll_read(&mut ctx, &mut received),
            Poll::Ready(Ok(()))
        );
        assert_eq!(received.filled(), b"PING\r\n");

        // Peer -> client: PONG, after which the client is idle again.
        assert_matches!(
            Pin::new(&mut peer).poll_write(&mut ctx, b"PONG\r\n"),
            Poll::Ready(Ok(6))
        );
        assert_matches!(
            client.poll_read_next(&mut ctx),
            Poll::Ready(Ok(ServerOp::Pong))
        );
        assert!(client.poll_read_next(&mut ctx).is_pending());
    }
}
253}