watermelon_net/connection/
websocket.rs

1use std::{
2    future, io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::Bytes;
8use futures_core::Stream as _;
9use futures_sink::Sink;
10use futures_util::task::noop_waker_ref;
11use http::Uri;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
14use watermelon_proto::proto::{
15    decode_frame, error::FrameDecoderError, ClientOp, FramedEncoder, ServerOp,
16};
17
/// A client connection that exchanges protocol operations over a websocket.
///
/// Incoming binary websocket messages are decoded into [`ServerOp`]s, and
/// outgoing [`ClientOp`]s are encoded and sent as binary websocket messages.
#[derive(Debug)]
pub struct WebsocketConnection<S> {
    /// The websocket stream obtained from the client handshake.
    socket: WebSocketStream<S>,
    /// Encoder used to turn each outgoing [`ClientOp`] into a message payload.
    encoder: FramedEncoder,
    /// Payload of the most recently received binary message that still has
    /// bytes left to decode. Empty when there is nothing pending.
    residual_frame: Bytes,
    /// Set to `true` whenever an operation is enqueued for writing.
    should_flush: bool,
}
25
26impl<S> WebsocketConnection<S>
27where
28    S: AsyncRead + AsyncWrite + Unpin,
29{
30    /// Construct a websocket stream to a pre-established connection `socket`.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if the websocket handshake fails.
35    pub async fn new(uri: Uri, socket: S) -> io::Result<Self> {
36        let (socket, _resp) = ClientBuilder::from_uri(uri)
37            .connect_on(socket)
38            .await
39            .map_err(websockets_error_to_io)?;
40        Ok(Self {
41            socket,
42            encoder: FramedEncoder::new(),
43            residual_frame: Bytes::new(),
44            should_flush: false,
45        })
46    }
47
48    pub fn poll_read_next(
49        &mut self,
50        cx: &mut Context<'_>,
51    ) -> Poll<Result<ServerOp, WebsocketReadError>> {
52        loop {
53            if !self.residual_frame.is_empty() {
54                return Poll::Ready(
55                    decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder),
56                );
57            }
58
59            match Pin::new(&mut self.socket).poll_next(cx) {
60                Poll::Pending => return Poll::Pending,
61                Poll::Ready(Some(Ok(message))) if message.is_binary() => {
62                    self.residual_frame = message.into_payload().into();
63                }
64                Poll::Ready(Some(Ok(_message))) => {}
65                Poll::Ready(Some(Err(err))) => {
66                    return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err))))
67                }
68                Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)),
69            }
70        }
71    }
72
73    /// Reads the next [`ServerOp`].
74    ///
75    /// # Errors
76    ///
77    /// It returns an error if the content cannot be decoded or if an I/O error occurs.
78    pub async fn read_next(&mut self) -> Result<ServerOp, WebsocketReadError> {
79        future::poll_fn(|cx| self.poll_read_next(cx)).await
80    }
81
82    pub fn should_flush(&self) -> bool {
83        self.should_flush
84    }
85
86    pub fn may_enqueue_more_ops(&mut self) -> bool {
87        // TODO: switch to `std::task::Waker::noop` with MSRV >= 1.85
88        let mut cx = Context::from_waker(noop_waker_ref());
89        Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready()
90    }
91
92    /// Enqueue `item` to be written.
93    #[expect(clippy::missing_panics_doc)]
94    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
95        let payload = self.encoder.encode(item);
96        Pin::new(&mut self.socket)
97            .start_send(Message::binary(payload))
98            .unwrap();
99        self.should_flush = true;
100    }
101
102    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
103        Pin::new(&mut self.socket)
104            .poll_flush(cx)
105            .map_err(websockets_error_to_io)
106    }
107
108    /// Flush any buffered writes to the connection
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if flushing fails
113    pub async fn flush(&mut self) -> io::Result<()> {
114        future::poll_fn(|cx| self.poll_flush(cx)).await
115    }
116
117    /// Shutdown the connection
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if shutting down the connection fails.
122    /// Implementations usually ignore this error.
123    pub async fn shutdown(&mut self) -> io::Result<()> {
124        future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx))
125            .await
126            .map_err(websockets_error_to_io)
127    }
128}
129
/// Errors produced while reading the next [`ServerOp`] from a
/// [`WebsocketConnection`].
#[derive(Debug, thiserror::Error)]
pub enum WebsocketReadError {
    /// The bytes of a received binary message could not be decoded.
    #[error("decoder")]
    Decoder(#[source] FrameDecoderError),
    /// The websocket stream yielded an error, converted into an [`io::Error`].
    #[error("io")]
    Io(#[source] io::Error),
    /// The websocket stream ended and will yield no further messages.
    #[error("closed")]
    Closed,
}
139
140fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error {
141    match err {
142        tokio_websockets::Error::Io(err) => err,
143        err => io::Error::other(err),
144    }
145}