Skip to main content

ustreamer_transport/
session.rs

1//! WebTransport session management primitives built on top of `wtransport`.
2
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use ustreamer_proto::frame::FramePacket;
7use ustreamer_proto::input::InputEvent;
8use wtransport::endpoint::endpoint_side::Server;
9use wtransport::tls::Sha256Digest;
10use wtransport::{Connection, Endpoint, Identity, ServerConfig};
11
12use crate::TransportError;
13
/// Server-side TLS identity source.
///
/// Consumed by `WebTransportServer::bind`, which resolves it into a concrete
/// `Identity` together with the digest of the first certificate in its chain.
pub enum ServerIdentity {
    /// Use an existing certificate chain and private key.
    Provided(Identity),
    /// Generate a two-week self-signed identity for a known set of hostnames/IPs.
    SelfSigned { subject_alt_names: Vec<String> },
}
21
22impl ServerIdentity {
23    fn into_identity_and_hash(self) -> Result<(Identity, Sha256Digest), TransportError> {
24        let identity = match self {
25            ServerIdentity::Provided(identity) => identity,
26            ServerIdentity::SelfSigned { subject_alt_names } => {
27                Identity::self_signed(subject_alt_names.iter().map(String::as_str))
28                    .map_err(|err| TransportError::InitFailed(err.to_string()))?
29            }
30        };
31
32        let certificate_hash = {
33            let chain = identity.certificate_chain();
34            let Some(certificate) = chain.as_slice().first() else {
35                return Err(TransportError::InitFailed(
36                    "identity did not contain a certificate".to_owned(),
37                ));
38            };
39
40            certificate.hash()
41        };
42
43        Ok((identity, certificate_hash))
44    }
45}
46
/// Transport-layer configuration for the WebTransport endpoint.
///
/// Every field is handed to the `ServerConfig` builder by
/// `WebTransportServer::bind`.
pub struct TransportConfig {
    /// UDP socket bind address for the server endpoint.
    pub bind_address: SocketAddr,
    /// TLS identity used during the WebTransport handshake.
    pub identity: ServerIdentity,
    /// Keep-alive interval for preserving low-latency LAN sessions.
    /// Forwarded as-is to the builder's `keep_alive_interval`.
    pub keep_alive_interval: Option<Duration>,
    /// Maximum permitted idle time before the connection is timed out.
    /// The builder may reject this value, in which case `bind` fails.
    pub max_idle_timeout: Option<Duration>,
}
58
59impl TransportConfig {
60    /// Convenience helper for local development with a self-signed identity.
61    pub fn localhost_self_signed(bind_address: SocketAddr) -> Self {
62        Self {
63            bind_address,
64            identity: ServerIdentity::SelfSigned {
65                subject_alt_names: vec!["localhost".to_owned(), "127.0.0.1".to_owned()],
66            },
67            keep_alive_interval: Some(Duration::from_secs(3)),
68            max_idle_timeout: Some(Duration::from_secs(10)),
69        }
70    }
71}
72
/// An accepted WebTransport session request and the established session.
///
/// Produced by `WebTransportServer::accept_session`; `authority` and `path`
/// are copied from the session request before it is accepted.
pub struct AcceptedSession {
    /// Host/authority requested by the client.
    pub authority: String,
    /// Path requested by the client (e.g. `/stream`).
    pub path: String,
    /// Session handle for video, control, and input traffic.
    pub session: StreamSession,
}
82
/// WebTransport server endpoint that accepts browser sessions.
pub struct WebTransportServer {
    // Bound server endpoint that incoming sessions are accepted from.
    endpoint: Endpoint<Server>,
    // Digest of the first certificate of the identity the endpoint was bound with.
    certificate_hash: Sha256Digest,
}
88
89impl WebTransportServer {
90    /// Bind a WebTransport server endpoint on the configured socket.
91    pub fn bind(config: TransportConfig) -> Result<Self, TransportError> {
92        let (identity, certificate_hash) = config.identity.into_identity_and_hash()?;
93
94        let server_config = ServerConfig::builder()
95            .with_bind_address(config.bind_address)
96            .with_identity(identity)
97            .keep_alive_interval(config.keep_alive_interval)
98            .max_idle_timeout(config.max_idle_timeout)
99            .map_err(|err| TransportError::InitFailed(err.to_string()))?
100            .build();
101
102        let endpoint = Endpoint::server(server_config)
103            .map_err(|err| TransportError::InitFailed(err.to_string()))?;
104
105        Ok(Self {
106            endpoint,
107            certificate_hash,
108        })
109    }
110
    /// Returns the local socket address of the bound UDP endpoint.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error reported by the underlying endpoint.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.endpoint.local_addr()
    }
115
    /// Returns the certificate digest browsers/clients can pin during setup.
    ///
    /// This is the digest captured in `bind` from the first certificate of
    /// the server identity's chain.
    pub fn certificate_hash(&self) -> &Sha256Digest {
        &self.certificate_hash
    }
120
121    /// Accept the next WebTransport session and complete the session handshake.
122    pub async fn accept_session(&self) -> Result<AcceptedSession, TransportError> {
123        let incoming = self.endpoint.accept().await;
124        let request = incoming
125            .await
126            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
127
128        let authority = request.authority().to_string();
129        let path = request.path().to_string();
130        let connection = request
131            .accept()
132            .await
133            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
134
135        Ok(AcceptedSession {
136            authority,
137            path,
138            session: StreamSession { connection },
139        })
140    }
141}
142
/// Whether an input event arrived unreliably (datagram) or reliably (stream).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputReliability {
    /// The event was received as a datagram.
    Unreliable,
    /// The event was received over a uni- or bidirectional stream.
    Reliable,
}
149
/// Input event received from the browser session.
#[derive(Debug, Clone, Copy)]
pub struct ReceivedInput {
    /// Channel kind the event arrived on.
    pub reliability: InputReliability,
    /// The decoded input event.
    pub event: InputEvent,
}
156
/// Established WebTransport session used by the streaming loop.
///
/// Cloning a session clones the wrapped `Connection` value.
#[derive(Clone)]
pub struct StreamSession {
    // Connection obtained by accepting a session request.
    connection: Connection,
}
162
163impl StreamSession {
    /// Current estimate of round-trip time for the session.
    ///
    /// Delegates directly to the underlying connection.
    pub fn rtt(&self) -> Duration {
        self.connection.rtt()
    }
168
    /// Current peer address.
    ///
    /// Delegates directly to the underlying connection.
    pub fn remote_address(&self) -> SocketAddr {
        self.connection.remote_address()
    }
173
    /// Maximum datagram payload permitted by the current path MTU estimate.
    ///
    /// `send_datagram` treats a `None` result as datagrams being unsupported
    /// and rejects payloads longer than a `Some` value.
    pub fn max_datagram_size(&self) -> Option<usize> {
        self.connection.max_datagram_size()
    }
178
179    /// Send a single packetized frame fragment over QUIC datagram transport.
180    pub fn send_frame_packet(&self, packet: &FramePacket) -> Result<(), TransportError> {
181        let bytes = packet.to_bytes();
182        self.send_datagram(&bytes)
183    }
184
185    /// Send a batch of packetized frame fragments in order.
186    pub fn send_frame_packets(&self, packets: &[FramePacket]) -> Result<(), TransportError> {
187        for packet in packets {
188            self.send_frame_packet(packet)?;
189        }
190
191        Ok(())
192    }
193
194    /// Receive the next unreliable input datagram from the browser.
195    pub async fn recv_input_datagram(&self) -> Result<InputEvent, TransportError> {
196        let datagram = self
197            .connection
198            .receive_datagram()
199            .await
200            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
201
202        InputEvent::from_bytes(datagram.as_ref())
203            .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
204    }
205
206    /// Receive the next reliable input message from a uni- or bidirectional stream.
207    pub async fn recv_reliable_input(&self) -> Result<InputEvent, TransportError> {
208        let message = self.recv_reliable_message().await?;
209        InputEvent::from_bytes(&message)
210            .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
211    }
212
213    /// Receive the next input event, regardless of reliability mode.
214    pub async fn recv_input(&self) -> Result<ReceivedInput, TransportError> {
215        let datagram_connection = self.connection.clone();
216        let reliable_connection = self.connection.clone();
217
218        tokio::select! {
219            datagram = datagram_connection.receive_datagram() => {
220                let datagram = datagram.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
221                let event = InputEvent::from_bytes(datagram.as_ref())
222                    .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))?;
223
224                Ok(ReceivedInput {
225                    reliability: InputReliability::Unreliable,
226                    event,
227                })
228            }
229            reliable = recv_reliable_message_from(reliable_connection) => {
230                let bytes = reliable?;
231                let event = InputEvent::from_bytes(&bytes)
232                    .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))?;
233
234                Ok(ReceivedInput {
235                    reliability: InputReliability::Reliable,
236                    event,
237                })
238            }
239        }
240    }
241
242    /// Send a reliable control message to the browser using a unidirectional stream.
243    pub async fn send_control_message(&self, payload: &[u8]) -> Result<(), TransportError> {
244        let mut stream = self
245            .connection
246            .open_uni()
247            .await
248            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?
249            .await
250            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
251
252        stream
253            .write_all(payload)
254            .await
255            .map_err(|err| TransportError::StreamIo(err.to_string()))
256    }
257
258    fn send_datagram(&self, payload: &[u8]) -> Result<(), TransportError> {
259        let max = self
260            .max_datagram_size()
261            .ok_or(TransportError::DatagramsUnsupported)?;
262
263        if payload.len() > max {
264            return Err(TransportError::DatagramTooLarge {
265                size: payload.len(),
266                max,
267            });
268        }
269
270        self.connection
271            .send_datagram(payload)
272            .map_err(|err| TransportError::ConnectionFailed(err.to_string()))
273    }
274
    /// Wait for the next incoming uni- or bidirectional stream and return the
    /// bytes read from it.
    ///
    /// The connection is cloned because the free helper takes it by value.
    async fn recv_reliable_message(&self) -> Result<Vec<u8>, TransportError> {
        recv_reliable_message_from(self.connection.clone()).await
    }
278}
279
280async fn recv_reliable_message_from(connection: Connection) -> Result<Vec<u8>, TransportError> {
281    let uni_connection = connection.clone();
282    let bi_connection = connection;
283
284    tokio::select! {
285        uni = uni_connection.accept_uni() => {
286            let mut stream = uni.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
287            read_all(&mut stream).await
288        }
289        bi = bi_connection.accept_bi() => {
290            let (_, mut stream) = bi.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
291            read_all(&mut stream).await
292        }
293    }
294}
295
296async fn read_all(stream: &mut wtransport::RecvStream) -> Result<Vec<u8>, TransportError> {
297    let mut output = Vec::new();
298    let mut buffer = vec![0u8; 4096];
299
300    loop {
301        let bytes_read = stream
302            .read(&mut buffer)
303            .await
304            .map_err(|err| TransportError::StreamIo(err.to_string()))?;
305
306        match bytes_read {
307            Some(0) => break,
308            Some(bytes_read) => output.extend_from_slice(&buffer[..bytes_read]),
309            None => break,
310        }
311    }
312
313    Ok(output)
314}
315
316#[cfg(test)]
317mod tests {
318    use anyhow::Result;
319    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
320    use tokio::time::{Duration, timeout};
321    use wtransport::endpoint::endpoint_side::Client;
322    use wtransport::{ClientConfig, Endpoint};
323
324    use super::*;
325
    /// Connected server/client pair created by `loopback_pair` for the tests.
    struct LoopbackPair {
        // Never read by the tests; presumably held so the server endpoint
        // stays alive for the lifetime of the pair.
        _server: WebTransportServer,
        // Never read by the tests; presumably held so the client endpoint
        // stays alive for the lifetime of the pair.
        _client_endpoint: Endpoint<Client>,
        // Server-side session produced by `accept_session`.
        server_session: StreamSession,
        // Client-side connection returned by `connect`.
        client_connection: Connection,
        // Path the server saw in the session request.
        path: String,
    }
333
334    async fn loopback_pair() -> Result<LoopbackPair> {
335        let bind_address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
336        let server =
337            WebTransportServer::bind(TransportConfig::localhost_self_signed(bind_address))?;
338        let cert_hash = server.certificate_hash().clone();
339        let port = server.local_addr()?.port();
340
341        let client_config = ClientConfig::builder()
342            .with_bind_default()
343            .with_server_certificate_hashes([cert_hash])
344            .build();
345
346        let client_endpoint = Endpoint::client(client_config)?;
347        let url = format!("https://127.0.0.1:{port}/stream");
348
349        let (accepted, client_connection) = tokio::join!(
350            async {
351                Ok::<_, anyhow::Error>(
352                    timeout(Duration::from_secs(5), server.accept_session()).await??,
353                )
354            },
355            async {
356                Ok::<_, anyhow::Error>(
357                    timeout(Duration::from_secs(5), client_endpoint.connect(url)).await??,
358                )
359            }
360        );
361
362        let accepted = accepted?;
363        let client_connection = client_connection?;
364
365        Ok(LoopbackPair {
366            _server: server,
367            _client_endpoint: client_endpoint,
368            server_session: accepted.session,
369            client_connection,
370            path: accepted.path,
371        })
372    }
373
374    async fn read_client_stream(stream: &mut wtransport::RecvStream) -> Result<Vec<u8>> {
375        let mut output = Vec::new();
376        let mut buffer = vec![0u8; 4096];
377
378        loop {
379            let bytes_read = stream.read(&mut buffer).await?;
380            match bytes_read {
381                Some(0) => break,
382                Some(bytes_read) => output.extend_from_slice(&buffer[..bytes_read]),
383                None => break,
384            }
385        }
386
387        Ok(output)
388    }
389
390    #[tokio::test]
391    async fn accepts_session_and_receives_input_datagram() -> Result<()> {
392        let pair = loopback_pair().await?;
393        assert_eq!(pair.path, "/stream");
394
395        let input = InputEvent::PointerMove {
396            x: 0.25,
397            y: 0.75,
398            buttons: 1,
399            timestamp_ms: 4242,
400        };
401
402        pair.client_connection.send_datagram(&input.to_bytes())?;
403
404        let received = timeout(
405            Duration::from_secs(5),
406            pair.server_session.recv_input_datagram(),
407        )
408        .await??;
409
410        match received {
411            InputEvent::PointerMove {
412                x,
413                y,
414                buttons,
415                timestamp_ms,
416            } => {
417                assert!((x - 0.25).abs() < f32::EPSILON);
418                assert!((y - 0.75).abs() < f32::EPSILON);
419                assert_eq!(buttons, 1);
420                assert_eq!(timestamp_ms, 4242);
421            }
422            _ => panic!("expected pointer move"),
423        }
424
425        Ok(())
426    }
427
428    #[tokio::test]
429    async fn sends_frame_packets_over_datagrams() -> Result<()> {
430        let pair = loopback_pair().await?;
431
432        let packet = FramePacket {
433            frame_id: 7,
434            fragment_idx: 0,
435            fragment_count: 1,
436            timestamp_us: 123_456,
437            is_keyframe: true,
438            is_refine: false,
439            is_lossless: false,
440            payload: vec![1, 2, 3, 4, 5],
441        };
442
443        pair.server_session.send_frame_packet(&packet)?;
444
445        let datagram = timeout(
446            Duration::from_secs(5),
447            pair.client_connection.receive_datagram(),
448        )
449        .await??;
450        let decoded = FramePacket::from_bytes(datagram.as_ref())?;
451
452        assert_eq!(decoded.frame_id, 7);
453        assert_eq!(decoded.fragment_idx, 0);
454        assert_eq!(decoded.fragment_count, 1);
455        assert_eq!(decoded.timestamp_us, 123_456);
456        assert!(decoded.is_keyframe);
457        assert!(!decoded.is_refine);
458        assert!(!decoded.is_lossless);
459        assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]);
460
461        Ok(())
462    }
463
464    #[tokio::test]
465    async fn receives_reliable_input_and_sends_control_message() -> Result<()> {
466        let pair = loopback_pair().await?;
467
468        let mut send_stream = pair.client_connection.open_uni().await?.await?;
469        send_stream
470            .write_all(&InputEvent::KeyDown { code: 0x0041 }.to_bytes())
471            .await?;
472        drop(send_stream);
473
474        let received = timeout(
475            Duration::from_secs(5),
476            pair.server_session.recv_reliable_input(),
477        )
478        .await??;
479
480        match received {
481            InputEvent::KeyDown { code } => assert_eq!(code, 0x0041),
482            _ => panic!("expected key down"),
483        }
484
485        let control_message = b"codec=h265;mode=interactive";
486        pair.server_session
487            .send_control_message(control_message)
488            .await?;
489
490        let mut recv_stream =
491            timeout(Duration::from_secs(5), pair.client_connection.accept_uni()).await??;
492        let payload = read_client_stream(&mut recv_stream).await?;
493        assert_eq!(payload, control_message);
494
495        Ok(())
496    }
497}