Skip to main content

ustreamer_transport/
websocket.rs

1//! WebSocket transport fallback for browsers without WebTransport support.
2
3use std::net::SocketAddr;
4use std::sync::{Arc, Mutex as StdMutex};
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, StreamExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::sync::Mutex;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
12use tokio_tungstenite::{WebSocketStream, accept_hdr_async};
13use ustreamer_proto::frame::FramePacket;
14use ustreamer_proto::input::InputEvent;
15
16use crate::{InputReliability, ReceivedInput, TransportError};
17
18/// An accepted WebSocket session upgrade.
19pub struct AcceptedWebSocketSession {
20    /// Path requested by the client (for example `/stream`).
21    pub path: String,
22    /// Established WebSocket session.
23    pub session: WebSocketSession,
24}
25
26/// A TCP listener that upgrades incoming requests to WebSocket sessions.
27pub struct WebSocketServer {
28    listener: TcpListener,
29}
30
31impl WebSocketServer {
32    /// Bind a WebSocket fallback endpoint on the provided TCP address.
33    pub async fn bind(bind_address: SocketAddr) -> Result<Self, TransportError> {
34        let listener = TcpListener::bind(bind_address)
35            .await
36            .map_err(|err| TransportError::InitFailed(err.to_string()))?;
37        Ok(Self { listener })
38    }
39
40    /// Returns the local socket address of the bound TCP listener.
41    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
42        self.listener.local_addr()
43    }
44
45    /// Accept the next WebSocket session and complete the upgrade handshake.
46    pub async fn accept_session(&self) -> Result<AcceptedWebSocketSession, TransportError> {
47        let (stream, remote_address) = self
48            .listener
49            .accept()
50            .await
51            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
52
53        let path = Arc::new(StdMutex::new(None::<String>));
54        let path_capture = Arc::clone(&path);
55        let websocket = accept_hdr_async(stream, move |request: &Request, response: Response| {
56            if let Ok(mut slot) = path_capture.lock() {
57                *slot = Some(request.uri().path().to_owned());
58            }
59            Ok(response)
60        })
61        .await
62        .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
63
64        let path = path
65            .lock()
66            .ok()
67            .and_then(|mut slot| slot.take())
68            .unwrap_or_else(|| "/".to_owned());
69        let (writer, reader) = websocket.split();
70
71        Ok(AcceptedWebSocketSession {
72            path,
73            session: WebSocketSession {
74                writer: Arc::new(Mutex::new(writer)),
75                reader: Arc::new(Mutex::new(reader)),
76                remote_address,
77            },
78        })
79    }
80}
81
82/// Established WebSocket fallback session.
83#[derive(Clone)]
84pub struct WebSocketSession {
85    writer: Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>,
86    reader: Arc<Mutex<SplitStream<WebSocketStream<TcpStream>>>>,
87    remote_address: SocketAddr,
88}
89
90impl WebSocketSession {
91    /// Current peer address.
92    pub fn remote_address(&self) -> SocketAddr {
93        self.remote_address
94    }
95
96    /// Send a single packetized frame fragment over WebSocket binary transport.
97    pub async fn send_frame_packet(&self, packet: &FramePacket) -> Result<(), TransportError> {
98        self.send_message(Message::Binary(packet.to_bytes().into()))
99            .await
100    }
101
102    /// Send a batch of frame fragments in order.
103    pub async fn send_frame_packets(&self, packets: &[FramePacket]) -> Result<(), TransportError> {
104        for packet in packets {
105            self.send_frame_packet(packet).await?;
106        }
107
108        Ok(())
109    }
110
111    /// Send a reliable UTF-8 control message to the browser.
112    pub async fn send_control_message(&self, payload: &[u8]) -> Result<(), TransportError> {
113        let text = String::from_utf8(payload.to_vec()).map_err(|err| {
114            TransportError::StreamIo(format!("control payload was not utf-8: {err}"))
115        })?;
116        self.send_message(Message::Text(text.into())).await
117    }
118
119    /// Receive the next reliable input event from the browser.
120    pub async fn recv_reliable_input(&self) -> Result<InputEvent, TransportError> {
121        let bytes = self.recv_binary_message().await?;
122        InputEvent::from_bytes(&bytes)
123            .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
124    }
125
126    /// Receive the next input event from the browser.
127    pub async fn recv_input(&self) -> Result<ReceivedInput, TransportError> {
128        Ok(ReceivedInput {
129            reliability: InputReliability::Reliable,
130            event: self.recv_reliable_input().await?,
131        })
132    }
133
134    async fn send_message(&self, message: Message) -> Result<(), TransportError> {
135        let mut writer = self.writer.lock().await;
136        writer
137            .send(message)
138            .await
139            .map_err(|err| TransportError::StreamIo(err.to_string()))
140    }
141
142    async fn recv_binary_message(&self) -> Result<Vec<u8>, TransportError> {
143        loop {
144            let next_message = {
145                let mut reader = self.reader.lock().await;
146                reader.next().await
147            };
148
149            match next_message {
150                Some(Ok(Message::Binary(bytes))) => return Ok(bytes.to_vec()),
151                Some(Ok(Message::Text(_))) => {
152                    return Err(TransportError::InvalidInputEvent(
153                        "expected binary input event over WebSocket".into(),
154                    ));
155                }
156                Some(Ok(Message::Close(_))) | None => return Err(TransportError::SessionClosed),
157                Some(Ok(Message::Ping(_)))
158                | Some(Ok(Message::Pong(_)))
159                | Some(Ok(Message::Frame(_))) => continue,
160                Some(Err(err)) => return Err(TransportError::StreamIo(err.to_string())),
161            }
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use anyhow::{Result, anyhow};
    use futures_util::{SinkExt, StreamExt};
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use tokio::time::{Duration, timeout};
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message;

    use super::*;

    /// Upper bound on any single network step, so a broken transport fails the test
    /// instead of hanging it.
    const STEP_TIMEOUT: Duration = Duration::from_secs(2);

    /// Client side of the loopback connection, as returned by `connect_async`.
    type ClientSocket = WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>;

    /// A connected server session and its client, both on loopback.
    struct LoopbackPair {
        // Held so the listener is not dropped while the pair is in use.
        _server: WebSocketServer,
        server_session: WebSocketSession,
        client_socket: ClientSocket,
        path: String,
    }

    /// Bind on an ephemeral loopback port and connect a client to `/stream`.
    async fn loopback_pair() -> Result<LoopbackPair> {
        let bind_address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let server = WebSocketServer::bind(bind_address).await?;
        let port = server.local_addr()?.port();
        let url = format!("ws://127.0.0.1:{port}/stream");

        let (accepted, client) = tokio::join!(server.accept_session(), connect_async(url));
        let accepted = accepted?;
        let (client_socket, _) = client?;

        Ok(LoopbackPair {
            _server: server,
            server_session: accepted.session,
            client_socket,
            path: accepted.path,
        })
    }

    /// Read the next message the server sent, failing on timeout, error, or close.
    async fn next_client_message(client_socket: &mut ClientSocket) -> Result<Message> {
        timeout(STEP_TIMEOUT, client_socket.next())
            .await?
            .transpose()?
            .ok_or_else(|| anyhow!("client websocket closed"))
    }

    #[tokio::test]
    async fn accepts_websocket_session_and_receives_input() -> Result<()> {
        let mut pair = loopback_pair().await?;
        assert_eq!(pair.path, "/stream");

        let event = InputEvent::KeyDown { code: 0x0041 };
        pair.client_socket
            .send(Message::Binary(event.to_bytes().into()))
            .await?;

        let received = timeout(STEP_TIMEOUT, pair.server_session.recv_reliable_input()).await??;
        match received {
            InputEvent::KeyDown { code } => assert_eq!(code, 0x0041),
            other => panic!("unexpected input event: {other:?}"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn sends_frame_packets_over_websocket_binary_messages() -> Result<()> {
        let mut pair = loopback_pair().await?;

        let packet = FramePacket {
            frame_id: 7,
            fragment_idx: 0,
            fragment_count: 1,
            timestamp_us: 999,
            is_keyframe: true,
            is_refine: true,
            is_lossless: true,
            payload: vec![1, 2, 3, 4],
        };

        pair.server_session.send_frame_packet(&packet).await?;

        let Message::Binary(bytes) = next_client_message(&mut pair.client_socket).await? else {
            panic!("expected binary frame message");
        };
        let decoded = FramePacket::from_bytes(&bytes)?;
        assert_eq!(decoded.frame_id, 7);
        assert!(decoded.is_keyframe);
        assert!(decoded.is_refine);
        assert!(decoded.is_lossless);

        Ok(())
    }

    #[tokio::test]
    async fn sends_frame_packet_batches_in_order() -> Result<()> {
        let mut pair = loopback_pair().await?;

        let packets: Vec<FramePacket> = (0..3)
            .map(|fragment_idx| FramePacket {
                frame_id: 11,
                fragment_idx,
                fragment_count: 3,
                timestamp_us: 1_000,
                is_keyframe: false,
                is_refine: false,
                is_lossless: false,
                payload: vec![9, 8, 7],
            })
            .collect();

        pair.server_session.send_frame_packets(&packets).await?;

        // Each fragment must arrive as its own binary message, in the order given.
        for expected in &packets {
            let Message::Binary(bytes) = next_client_message(&mut pair.client_socket).await?
            else {
                panic!("expected binary frame message");
            };
            let decoded = FramePacket::from_bytes(&bytes)?;
            assert_eq!(decoded.frame_id, expected.frame_id);
            assert_eq!(decoded.fragment_idx, expected.fragment_idx);
            assert_eq!(decoded.fragment_count, expected.fragment_count);
            assert_eq!(decoded.payload, expected.payload);
        }

        Ok(())
    }

    #[tokio::test]
    async fn sends_control_messages_as_text() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.server_session
            .send_control_message(br#"{"type":"status","message":"ok"}"#)
            .await?;

        let Message::Text(text) = next_client_message(&mut pair.client_socket).await? else {
            panic!("expected text control message");
        };
        assert_eq!(text, r#"{"type":"status","message":"ok"}"#);

        Ok(())
    }

    #[tokio::test]
    async fn rejects_text_input_messages() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.client_socket
            .send(Message::Text("not an input event".to_owned().into()))
            .await?;

        let result = timeout(STEP_TIMEOUT, pair.server_session.recv_reliable_input()).await?;
        assert!(
            matches!(result, Err(TransportError::InvalidInputEvent(_))),
            "expected InvalidInputEvent, got {result:?}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn reports_session_closed_when_client_closes() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.client_socket.close(None).await?;

        let result = timeout(STEP_TIMEOUT, pair.server_session.recv_reliable_input()).await?;
        assert!(
            matches!(result, Err(TransportError::SessionClosed)),
            "expected SessionClosed, got {result:?}"
        );

        Ok(())
    }
}