ustreamer_transport/
websocket.rs1use std::net::SocketAddr;
4use std::sync::{Arc, Mutex as StdMutex};
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, StreamExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::sync::Mutex;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
12use tokio_tungstenite::{WebSocketStream, accept_hdr_async};
13use ustreamer_proto::frame::FramePacket;
14use ustreamer_proto::input::InputEvent;
15
16use crate::{InputReliability, ReceivedInput, TransportError};
17
18pub struct AcceptedWebSocketSession {
20 pub path: String,
22 pub session: WebSocketSession,
24}
25
26pub struct WebSocketServer {
28 listener: TcpListener,
29}
30
31impl WebSocketServer {
32 pub async fn bind(bind_address: SocketAddr) -> Result<Self, TransportError> {
34 let listener = TcpListener::bind(bind_address)
35 .await
36 .map_err(|err| TransportError::InitFailed(err.to_string()))?;
37 Ok(Self { listener })
38 }
39
40 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
42 self.listener.local_addr()
43 }
44
45 pub async fn accept_session(&self) -> Result<AcceptedWebSocketSession, TransportError> {
47 let (stream, remote_address) = self
48 .listener
49 .accept()
50 .await
51 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
52
53 let path = Arc::new(StdMutex::new(None::<String>));
54 let path_capture = Arc::clone(&path);
55 let websocket = accept_hdr_async(stream, move |request: &Request, response: Response| {
56 if let Ok(mut slot) = path_capture.lock() {
57 *slot = Some(request.uri().path().to_owned());
58 }
59 Ok(response)
60 })
61 .await
62 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
63
64 let path = path
65 .lock()
66 .ok()
67 .and_then(|mut slot| slot.take())
68 .unwrap_or_else(|| "/".to_owned());
69 let (writer, reader) = websocket.split();
70
71 Ok(AcceptedWebSocketSession {
72 path,
73 session: WebSocketSession {
74 writer: Arc::new(Mutex::new(writer)),
75 reader: Arc::new(Mutex::new(reader)),
76 remote_address,
77 },
78 })
79 }
80}
81
82#[derive(Clone)]
84pub struct WebSocketSession {
85 writer: Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>,
86 reader: Arc<Mutex<SplitStream<WebSocketStream<TcpStream>>>>,
87 remote_address: SocketAddr,
88}
89
90impl WebSocketSession {
91 pub fn remote_address(&self) -> SocketAddr {
93 self.remote_address
94 }
95
96 pub async fn send_frame_packet(&self, packet: &FramePacket) -> Result<(), TransportError> {
98 self.send_message(Message::Binary(packet.to_bytes().into()))
99 .await
100 }
101
102 pub async fn send_frame_packets(&self, packets: &[FramePacket]) -> Result<(), TransportError> {
104 for packet in packets {
105 self.send_frame_packet(packet).await?;
106 }
107
108 Ok(())
109 }
110
111 pub async fn send_control_message(&self, payload: &[u8]) -> Result<(), TransportError> {
113 let text = String::from_utf8(payload.to_vec()).map_err(|err| {
114 TransportError::StreamIo(format!("control payload was not utf-8: {err}"))
115 })?;
116 self.send_message(Message::Text(text.into())).await
117 }
118
119 pub async fn recv_reliable_input(&self) -> Result<InputEvent, TransportError> {
121 let bytes = self.recv_binary_message().await?;
122 InputEvent::from_bytes(&bytes)
123 .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
124 }
125
126 pub async fn recv_input(&self) -> Result<ReceivedInput, TransportError> {
128 Ok(ReceivedInput {
129 reliability: InputReliability::Reliable,
130 event: self.recv_reliable_input().await?,
131 })
132 }
133
134 async fn send_message(&self, message: Message) -> Result<(), TransportError> {
135 let mut writer = self.writer.lock().await;
136 writer
137 .send(message)
138 .await
139 .map_err(|err| TransportError::StreamIo(err.to_string()))
140 }
141
142 async fn recv_binary_message(&self) -> Result<Vec<u8>, TransportError> {
143 loop {
144 let next_message = {
145 let mut reader = self.reader.lock().await;
146 reader.next().await
147 };
148
149 match next_message {
150 Some(Ok(Message::Binary(bytes))) => return Ok(bytes.to_vec()),
151 Some(Ok(Message::Text(_))) => {
152 return Err(TransportError::InvalidInputEvent(
153 "expected binary input event over WebSocket".into(),
154 ));
155 }
156 Some(Ok(Message::Close(_))) | None => return Err(TransportError::SessionClosed),
157 Some(Ok(Message::Ping(_)))
158 | Some(Ok(Message::Pong(_)))
159 | Some(Ok(Message::Frame(_))) => continue,
160 Some(Err(err)) => return Err(TransportError::StreamIo(err.to_string())),
161 }
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use anyhow::{Result, anyhow};
    use futures_util::{SinkExt, StreamExt};
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use tokio::time::{Duration, timeout};
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message;

    use super::*;

    /// Upper bound on how long a test waits for any single message.
    const RECV_TIMEOUT: Duration = Duration::from_secs(2);

    /// A server session and a client socket connected to each other over loopback.
    struct LoopbackPair {
        _server: WebSocketServer,
        server_session: WebSocketSession,
        client_socket: WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
        path: String,
    }

    /// Binds a server on an ephemeral localhost port and connects a client to `/stream`.
    async fn loopback_pair() -> Result<LoopbackPair> {
        let bind_address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let server = WebSocketServer::bind(bind_address).await?;
        let port = server.local_addr()?.port();
        let url = format!("ws://127.0.0.1:{port}/stream");

        let (accepted, client) = tokio::join!(server.accept_session(), connect_async(url));
        let accepted = accepted?;
        let (client_socket, _) = client?;

        Ok(LoopbackPair {
            _server: server,
            server_session: accepted.session,
            client_socket,
            path: accepted.path,
        })
    }

    /// Reads the next message on the client side, failing on timeout or close.
    async fn next_client_message(pair: &mut LoopbackPair) -> Result<Message> {
        timeout(RECV_TIMEOUT, pair.client_socket.next())
            .await?
            .transpose()?
            .ok_or_else(|| anyhow!("client websocket closed"))
    }

    #[tokio::test]
    async fn accepts_websocket_session_and_receives_input() -> Result<()> {
        let mut pair = loopback_pair().await?;
        assert_eq!(pair.path, "/stream");

        let event = InputEvent::KeyDown { code: 0x0041 };
        pair.client_socket
            .send(Message::Binary(event.to_bytes().into()))
            .await?;

        let received = timeout(RECV_TIMEOUT, pair.server_session.recv_reliable_input()).await??;
        match received {
            InputEvent::KeyDown { code } => assert_eq!(code, 0x0041),
            other => panic!("unexpected input event: {other:?}"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn rejects_text_messages_as_input() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.client_socket
            .send(Message::Text(String::from("not binary").into()))
            .await?;

        let outcome = timeout(RECV_TIMEOUT, pair.server_session.recv_reliable_input()).await?;
        assert!(
            matches!(outcome, Err(TransportError::InvalidInputEvent(_))),
            "text input must be rejected as an invalid input event"
        );

        Ok(())
    }

    #[tokio::test]
    async fn reports_session_closed_when_client_closes() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.client_socket.send(Message::Close(None)).await?;

        let outcome = timeout(RECV_TIMEOUT, pair.server_session.recv_reliable_input()).await?;
        assert!(
            matches!(outcome, Err(TransportError::SessionClosed)),
            "a client close must surface as SessionClosed"
        );

        Ok(())
    }

    #[tokio::test]
    async fn sends_frame_packets_over_websocket_binary_messages() -> Result<()> {
        let mut pair = loopback_pair().await?;

        let packet = FramePacket {
            frame_id: 7,
            fragment_idx: 0,
            fragment_count: 1,
            timestamp_us: 999,
            is_keyframe: true,
            is_refine: true,
            is_lossless: true,
            payload: vec![1, 2, 3, 4],
        };

        pair.server_session.send_frame_packet(&packet).await?;

        let Message::Binary(bytes) = next_client_message(&mut pair).await? else {
            panic!("expected binary frame message");
        };
        let decoded = FramePacket::from_bytes(&bytes)?;
        assert_eq!(decoded.frame_id, 7);
        assert!(decoded.is_keyframe);
        assert!(decoded.is_refine);
        assert!(decoded.is_lossless);

        Ok(())
    }

    #[tokio::test]
    async fn sends_frame_packet_batches_in_order() -> Result<()> {
        let mut pair = loopback_pair().await?;

        let packets: Vec<FramePacket> = (0..3)
            .map(|idx| FramePacket {
                frame_id: 11,
                fragment_idx: idx,
                fragment_count: 3,
                timestamp_us: 1_234,
                is_keyframe: false,
                is_refine: false,
                is_lossless: false,
                payload: vec![idx as u8; 4],
            })
            .collect();

        pair.server_session.send_frame_packets(&packets).await?;

        for expected in &packets {
            let Message::Binary(bytes) = next_client_message(&mut pair).await? else {
                panic!("expected binary frame message");
            };
            let decoded = FramePacket::from_bytes(&bytes)?;
            assert_eq!(decoded.frame_id, expected.frame_id);
            assert_eq!(decoded.fragment_idx, expected.fragment_idx);
            assert_eq!(decoded.payload[..], expected.payload[..]);
        }

        Ok(())
    }

    #[tokio::test]
    async fn sends_control_messages_as_text() -> Result<()> {
        let mut pair = loopback_pair().await?;

        pair.server_session
            .send_control_message(br#"{"type":"status","message":"ok"}"#)
            .await?;

        let Message::Text(text) = next_client_message(&mut pair).await? else {
            panic!("expected text control message");
        };
        assert_eq!(text, r#"{"type":"status","message":"ok"}"#);

        Ok(())
    }
}