Skip to main content

vox_websocket/
native.rs

1//! Native (tokio-tungstenite) WebSocket transport implementing [`Link`].
2
3use std::io;
4
5use futures_util::{SinkExt, StreamExt};
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::sync::mpsc;
8use tokio::task::JoinHandle;
9use tokio_tungstenite::WebSocketStream;
10use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
11
12use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};
13
14/// A [`Link`](vox_types::Link) over a WebSocket connection.
15///
16/// Wraps a [`WebSocketStream`] and sends each vox payload as a single
17/// binary WebSocket frame. The WebSocket protocol preserves message
18/// boundaries natively, so no length-prefix framing is needed.
19// r[impl transport.websocket]
20// r[impl transport.websocket.platforms]
21// r[impl zerocopy.framing.link.websocket]
22pub struct WsLink<S> {
23    stream: WebSocketStream<S>,
24}
25
26impl<S> WsLink<S> {
27    /// Construct from an already-upgraded [`WebSocketStream`].
28    pub fn new(stream: WebSocketStream<S>) -> Self {
29        Self { stream }
30    }
31}
32
33impl<S> WsLink<S>
34where
35    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
36{
37    /// Accept a server-side WebSocket handshake over a raw stream.
38    pub async fn server(stream: S) -> Result<Self, io::Error> {
39        let ws = tokio_tungstenite::accept_async(stream)
40            .await
41            .map_err(|e| io::Error::other(e.to_string()))?;
42        Ok(Self::new(ws))
43    }
44}
45
46impl<S> Link for WsLink<S>
47where
48    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
49{
50    type Tx = WsLinkTx;
51    type Rx = WsLinkRx;
52
53    fn split(self) -> (Self::Tx, Self::Rx) {
54        let (tx_out, rx_out) = mpsc::channel::<Vec<u8>>(1);
55        let (tx_in, rx_in) = mpsc::channel::<Result<WsMessage, io::Error>>(1);
56
57        let io_task = tokio::spawn(io_loop(self.stream, rx_out, tx_in));
58
59        (
60            WsLinkTx {
61                tx: tx_out,
62                io_task,
63            },
64            WsLinkRx { rx: rx_in },
65        )
66    }
67}
68
69/// Background I/O task that owns the WebSocketStream.
70///
71/// Multiplexes outbound writes (from the mpsc channel) with inbound reads
72/// (forwarded to the read channel). When the write channel closes, the
73/// entire WebSocket stream is dropped, causing the read side to see EOF.
74async fn io_loop<S>(
75    mut ws: WebSocketStream<S>,
76    mut rx_out: mpsc::Receiver<Vec<u8>>,
77    tx_in: mpsc::Sender<Result<WsMessage, io::Error>>,
78) where
79    S: AsyncRead + AsyncWrite + Unpin,
80{
81    loop {
82        tokio::select! {
83            // Outbound: drain the write channel and send as binary frames.
84            msg = rx_out.recv() => {
85                match msg {
86                    Some(bytes) => {
87                        if let Err(e) = ws.feed(WsMessage::binary(bytes)).await {
88                            let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
89                            return;
90                        }
91                        // Coalesce: drain any already-queued messages before flushing.
92                        while let Ok(bytes) = rx_out.try_recv() {
93                            if let Err(e) = ws.feed(WsMessage::binary(bytes)).await {
94                                let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
95                                return;
96                            }
97                        }
98                        if let Err(e) = ws.flush().await {
99                            let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
100                            return;
101                        }
102                    }
103                    None => {
104                        // Write channel closed — drop the WebSocket stream.
105                        // This closes the underlying transport, causing the
106                        // peer's read side to see EOF.
107                        return;
108                    }
109                }
110            }
111            // Inbound: read from the WebSocket and forward to the read channel.
112            frame = ws.next() => {
113                match frame {
114                    Some(Ok(msg)) => {
115                        if tx_in.send(Ok(msg)).await.is_err() {
116                            // Reader dropped — shut down.
117                            return;
118                        }
119                    }
120                    Some(Err(e)) => {
121                        use tokio_tungstenite::tungstenite::error::ProtocolError;
122                        use tokio_tungstenite::tungstenite::Error as WsError;
123                        match &e {
124                            // The peer dropped the connection without a close
125                            // handshake — this is just EOF for our purposes.
126                            WsError::Protocol(
127                                ProtocolError::ResetWithoutClosingHandshake,
128                            ) => {
129                                return;
130                            }
131                            _ => {
132                                let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
133                                return;
134                            }
135                        }
136                    }
137                    None => {
138                        // WebSocket stream ended.
139                        return;
140                    }
141                }
142            }
143        }
144    }
145}
146
147// ---------------------------------------------------------------------------
148// Tx
149// ---------------------------------------------------------------------------
150
151/// Sending half of a [`WsLink`].
152///
153/// Internally uses a bounded mpsc channel (capacity 1) to serialize writes
154/// and provide backpressure. The I/O task drains the channel and sends
155/// binary WebSocket frames.
156pub struct WsLinkTx {
157    tx: mpsc::Sender<Vec<u8>>,
158    io_task: JoinHandle<()>,
159}
160
161/// Permit for sending one payload through a [`WsLinkTx`].
162pub struct WsLinkTxPermit {
163    permit: mpsc::OwnedPermit<Vec<u8>>,
164}
165
166/// Write slot for [`WsLinkTx`].
167pub struct WsWriteSlot {
168    buf: Vec<u8>,
169    permit: mpsc::OwnedPermit<Vec<u8>>,
170}
171
172impl LinkTx for WsLinkTx {
173    type Permit = WsLinkTxPermit;
174
175    async fn reserve(&self) -> io::Result<Self::Permit> {
176        let permit = self.tx.clone().reserve_owned().await.map_err(|_| {
177            io::Error::new(
178                io::ErrorKind::ConnectionReset,
179                "websocket writer task stopped",
180            )
181        })?;
182        Ok(WsLinkTxPermit { permit })
183    }
184
185    async fn close(self) -> io::Result<()> {
186        drop(self.tx);
187        self.io_task.await.map_err(io::Error::other)
188    }
189}
190
191// r[impl zerocopy.send.websocket]
192impl LinkTxPermit for WsLinkTxPermit {
193    type Slot = WsWriteSlot;
194
195    fn alloc(self, len: usize) -> io::Result<Self::Slot> {
196        Ok(WsWriteSlot {
197            buf: vec![0u8; len],
198            permit: self.permit,
199        })
200    }
201}
202
203impl WriteSlot for WsWriteSlot {
204    fn as_mut_slice(&mut self) -> &mut [u8] {
205        &mut self.buf
206    }
207
208    fn commit(self) {
209        drop(self.permit.send(self.buf));
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Rx
215// ---------------------------------------------------------------------------
216
217/// Receiving half of a [`WsLink`].
218pub struct WsLinkRx {
219    rx: mpsc::Receiver<Result<WsMessage, io::Error>>,
220}
221
/// Error produced by [`WsLinkRx`] when receiving fails.
///
/// Wraps the underlying [`io::Error`], which is also exposed as the error's
/// [`source`](std::error::Error::source).
#[derive(Debug)]
pub struct WsLinkRxError(io::Error);

impl std::fmt::Display for WsLinkRxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WsLinkRxError(inner) = self;
        write!(f, "websocket rx: {inner}")
    }
}

impl std::error::Error for WsLinkRxError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let WsLinkRxError(inner) = self;
        Some(inner)
    }
}
237
238// r[impl zerocopy.recv.websocket]
239impl LinkRx for WsLinkRx {
240    type Error = WsLinkRxError;
241
242    async fn recv(&mut self) -> Result<Option<Backing>, Self::Error> {
243        loop {
244            match self.rx.recv().await {
245                Some(Ok(WsMessage::Binary(data))) => {
246                    return Ok(Some(Backing::Boxed(Vec::from(data).into_boxed_slice())));
247                }
248                Some(Ok(WsMessage::Close(_))) | None => {
249                    return Ok(None);
250                }
251                Some(Ok(WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_))) => {
252                    continue;
253                }
254                Some(Ok(WsMessage::Text(_))) => {
255                    return Err(WsLinkRxError(io::Error::new(
256                        io::ErrorKind::InvalidData,
257                        "text frames not allowed on vox websocket link",
258                    )));
259                }
260                Some(Err(e)) => {
261                    return Err(WsLinkRxError(e));
262                }
263            }
264        }
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Tests
270// ---------------------------------------------------------------------------
271
#[cfg(test)]
mod tests {
    use tokio_tungstenite::WebSocketStream;
    use tokio_tungstenite::tungstenite::protocol::Role;
    use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};

    use super::*;

    type TestWsLink = WsLink<tokio::io::DuplexStream>;

    /// Builds two connected links over an in-memory duplex pipe. The first
    /// link plays the server role and the second the client role.
    async fn ws_pair() -> (TestWsLink, TestWsLink) {
        let (server_io, client_io) = tokio::io::duplex(64 * 1024);
        let server = WebSocketStream::from_raw_socket(server_io, Role::Server, None).await;
        let client = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
        (WsLink::new(server), WsLink::new(client))
    }

    /// Borrows the payload bytes regardless of how the backing stores them.
    fn payload(backing: &Backing) -> &[u8] {
        match backing {
            Backing::Boxed(bytes) => bytes,
            Backing::Shared(shared) => shared.as_bytes(),
        }
    }

    /// Reserves, fills and commits one payload on `tx`.
    async fn send_payload(tx: &WsLinkTx, data: &[u8]) {
        let permit = tx.reserve().await.unwrap();
        let mut slot = permit.alloc(data.len()).unwrap();
        slot.as_mut_slice().copy_from_slice(data);
        slot.commit();
    }

    /// Receives one payload from `rx`, panicking on error or end of stream.
    async fn recv_payload(rx: &mut WsLinkRx) -> Vec<u8> {
        let backing = rx.recv().await.unwrap().unwrap();
        payload(&backing).to_vec()
    }

    #[tokio::test]
    async fn round_trip_single() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        send_payload(&left_tx, b"hello").await;

        assert_eq!(recv_payload(&mut right_rx).await, b"hello");
    }

    #[tokio::test]
    async fn multiple_messages_in_order() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        let sent: &[&[u8]] = &[b"one", b"two", b"three", b"four"];
        for data in sent {
            send_payload(&left_tx, data).await;
        }

        for expected in sent {
            assert_eq!(recv_payload(&mut right_rx).await, *expected);
        }
    }

    // r[verify link.message.empty]
    #[tokio::test]
    async fn empty_payload() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        send_payload(&left_tx, b"").await;

        assert_eq!(recv_payload(&mut right_rx).await, b"");
    }

    // r[verify link.rx.eof]
    #[tokio::test]
    async fn eof_on_peer_close() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        left_tx.close().await.unwrap();

        assert!(right_rx.recv().await.unwrap().is_none());
        // End of stream is sticky: asking again still yields None.
        assert!(right_rx.recv().await.unwrap().is_none());
    }

    // r[verify link.tx.permit.drop]
    #[tokio::test]
    async fn dropped_permit_sends_nothing() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        // A permit dropped before allocating must not produce a frame.
        let unused = left_tx.reserve().await.unwrap();
        drop(unused);

        send_payload(&left_tx, b"yep").await;

        assert_eq!(recv_payload(&mut right_rx).await, b"yep");
    }

    // r[verify link.tx.discard]
    #[tokio::test]
    async fn dropped_slot_sends_nothing() {
        let (left, right) = ws_pair().await;
        let (left_tx, _left_rx) = left.split();
        let (_right_tx, mut right_rx) = right.split();

        // A slot dropped without committing must not produce a frame.
        let permit = left_tx.reserve().await.unwrap();
        let abandoned = permit.alloc(3).unwrap();
        drop(abandoned);

        send_payload(&left_tx, b"ok").await;

        assert_eq!(recv_payload(&mut right_rx).await, b"ok");
    }
}