1use std::net::SocketAddr;
4use std::time::Duration;
5
6use ustreamer_proto::frame::FramePacket;
7use ustreamer_proto::input::InputEvent;
8use wtransport::endpoint::endpoint_side::Server;
9use wtransport::tls::Sha256Digest;
10use wtransport::{Connection, Endpoint, Identity, ServerConfig};
11
12use crate::TransportError;
13
/// Where the TLS identity (certificate chain + private key) presented by the server comes from.
pub enum ServerIdentity {
    /// Use a caller-supplied identity unchanged.
    Provided(Identity),
    /// Generate a self-signed identity when the server is bound, valid for the listed
    /// subject alternative names.
    SelfSigned { subject_alt_names: Vec<String> },
}
21
22impl ServerIdentity {
23 fn into_identity_and_hash(self) -> Result<(Identity, Sha256Digest), TransportError> {
24 let identity = match self {
25 ServerIdentity::Provided(identity) => identity,
26 ServerIdentity::SelfSigned { subject_alt_names } => {
27 Identity::self_signed(subject_alt_names.iter().map(String::as_str))
28 .map_err(|err| TransportError::InitFailed(err.to_string()))?
29 }
30 };
31
32 let certificate_hash = {
33 let chain = identity.certificate_chain();
34 let Some(certificate) = chain.as_slice().first() else {
35 return Err(TransportError::InitFailed(
36 "identity did not contain a certificate".to_owned(),
37 ));
38 };
39
40 certificate.hash()
41 };
42
43 Ok((identity, certificate_hash))
44 }
45}
46
/// Settings consumed by [`WebTransportServer::bind`].
pub struct TransportConfig {
    /// Local socket address the server endpoint binds to. Port 0 asks the OS for a free
    /// port; the chosen one is available from [`WebTransportServer::local_addr`].
    pub bind_address: SocketAddr,
    /// TLS identity the server presents to clients.
    pub identity: ServerIdentity,
    /// Forwarded to the server builder's keep-alive setting. NOTE(review): `None` presumably
    /// disables keep-alives; confirm against the wtransport docs.
    pub keep_alive_interval: Option<Duration>,
    /// Forwarded to the server builder's idle-timeout setting; the builder may reject the
    /// value. NOTE(review): `None` presumably means no idle timeout; confirm.
    pub max_idle_timeout: Option<Duration>,
}
58
59impl TransportConfig {
60 pub fn localhost_self_signed(bind_address: SocketAddr) -> Self {
62 Self {
63 bind_address,
64 identity: ServerIdentity::SelfSigned {
65 subject_alt_names: vec!["localhost".to_owned(), "127.0.0.1".to_owned()],
66 },
67 keep_alive_interval: Some(Duration::from_secs(3)),
68 max_idle_timeout: Some(Duration::from_secs(10)),
69 }
70 }
71}
72
/// A session accepted by [`WebTransportServer::accept_session`], plus metadata copied from
/// the client's session request.
pub struct AcceptedSession {
    /// Authority string from the client's session request.
    pub authority: String,
    /// Path from the client's session request (for example `/stream`).
    pub path: String,
    /// Handle used to exchange frames and input with this client.
    pub session: StreamSession,
}
82
/// A bound WebTransport server endpoint.
pub struct WebTransportServer {
    endpoint: Endpoint<Server>,
    // SHA-256 digest of the first certificate in the server's identity, computed once in
    // `bind` so clients can pin it (see `certificate_hash`).
    certificate_hash: Sha256Digest,
}
88
89impl WebTransportServer {
90 pub fn bind(config: TransportConfig) -> Result<Self, TransportError> {
92 let (identity, certificate_hash) = config.identity.into_identity_and_hash()?;
93
94 let server_config = ServerConfig::builder()
95 .with_bind_address(config.bind_address)
96 .with_identity(identity)
97 .keep_alive_interval(config.keep_alive_interval)
98 .max_idle_timeout(config.max_idle_timeout)
99 .map_err(|err| TransportError::InitFailed(err.to_string()))?
100 .build();
101
102 let endpoint = Endpoint::server(server_config)
103 .map_err(|err| TransportError::InitFailed(err.to_string()))?;
104
105 Ok(Self {
106 endpoint,
107 certificate_hash,
108 })
109 }
110
111 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
113 self.endpoint.local_addr()
114 }
115
116 pub fn certificate_hash(&self) -> &Sha256Digest {
118 &self.certificate_hash
119 }
120
121 pub async fn accept_session(&self) -> Result<AcceptedSession, TransportError> {
123 let incoming = self.endpoint.accept().await;
124 let request = incoming
125 .await
126 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
127
128 let authority = request.authority().to_string();
129 let path = request.path().to_string();
130 let connection = request
131 .accept()
132 .await
133 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
134
135 Ok(AcceptedSession {
136 authority,
137 path,
138 session: StreamSession { connection },
139 })
140 }
141}
142
/// Which delivery path an input event arrived on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputReliability {
    /// Received as a datagram.
    Unreliable,
    /// Received as the full contents of a peer-opened stream.
    Reliable,
}
149
/// An input event tagged with the path it arrived on, as returned by
/// [`StreamSession::recv_input`].
#[derive(Debug, Clone, Copy)]
pub struct ReceivedInput {
    /// Whether the event came from a datagram or a stream.
    pub reliability: InputReliability,
    /// The decoded event.
    pub event: InputEvent,
}
156
/// Server-side handle to one accepted session, used to send frame packets and control
/// messages and to receive input.
///
/// Clones refer to the same underlying connection.
#[derive(Clone)]
pub struct StreamSession {
    connection: Connection,
}
162
163impl StreamSession {
164 pub fn rtt(&self) -> Duration {
166 self.connection.rtt()
167 }
168
169 pub fn remote_address(&self) -> SocketAddr {
171 self.connection.remote_address()
172 }
173
174 pub fn max_datagram_size(&self) -> Option<usize> {
176 self.connection.max_datagram_size()
177 }
178
179 pub fn send_frame_packet(&self, packet: &FramePacket) -> Result<(), TransportError> {
181 let bytes = packet.to_bytes();
182 self.send_datagram(&bytes)
183 }
184
185 pub fn send_frame_packets(&self, packets: &[FramePacket]) -> Result<(), TransportError> {
187 for packet in packets {
188 self.send_frame_packet(packet)?;
189 }
190
191 Ok(())
192 }
193
194 pub async fn recv_input_datagram(&self) -> Result<InputEvent, TransportError> {
196 let datagram = self
197 .connection
198 .receive_datagram()
199 .await
200 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
201
202 InputEvent::from_bytes(datagram.as_ref())
203 .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
204 }
205
206 pub async fn recv_reliable_input(&self) -> Result<InputEvent, TransportError> {
208 let message = self.recv_reliable_message().await?;
209 InputEvent::from_bytes(&message)
210 .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))
211 }
212
213 pub async fn recv_input(&self) -> Result<ReceivedInput, TransportError> {
215 let datagram_connection = self.connection.clone();
216 let reliable_connection = self.connection.clone();
217
218 tokio::select! {
219 datagram = datagram_connection.receive_datagram() => {
220 let datagram = datagram.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
221 let event = InputEvent::from_bytes(datagram.as_ref())
222 .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))?;
223
224 Ok(ReceivedInput {
225 reliability: InputReliability::Unreliable,
226 event,
227 })
228 }
229 reliable = recv_reliable_message_from(reliable_connection) => {
230 let bytes = reliable?;
231 let event = InputEvent::from_bytes(&bytes)
232 .map_err(|err| TransportError::InvalidInputEvent(err.to_string()))?;
233
234 Ok(ReceivedInput {
235 reliability: InputReliability::Reliable,
236 event,
237 })
238 }
239 }
240 }
241
242 pub async fn send_control_message(&self, payload: &[u8]) -> Result<(), TransportError> {
244 let mut stream = self
245 .connection
246 .open_uni()
247 .await
248 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?
249 .await
250 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
251
252 stream
253 .write_all(payload)
254 .await
255 .map_err(|err| TransportError::StreamIo(err.to_string()))
256 }
257
258 fn send_datagram(&self, payload: &[u8]) -> Result<(), TransportError> {
259 let max = self
260 .max_datagram_size()
261 .ok_or(TransportError::DatagramsUnsupported)?;
262
263 if payload.len() > max {
264 return Err(TransportError::DatagramTooLarge {
265 size: payload.len(),
266 max,
267 });
268 }
269
270 self.connection
271 .send_datagram(payload)
272 .map_err(|err| TransportError::ConnectionFailed(err.to_string()))
273 }
274
275 async fn recv_reliable_message(&self) -> Result<Vec<u8>, TransportError> {
276 recv_reliable_message_from(self.connection.clone()).await
277 }
278}
279
280async fn recv_reliable_message_from(connection: Connection) -> Result<Vec<u8>, TransportError> {
281 let uni_connection = connection.clone();
282 let bi_connection = connection;
283
284 tokio::select! {
285 uni = uni_connection.accept_uni() => {
286 let mut stream = uni.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
287 read_all(&mut stream).await
288 }
289 bi = bi_connection.accept_bi() => {
290 let (_, mut stream) = bi.map_err(|err| TransportError::ConnectionFailed(err.to_string()))?;
291 read_all(&mut stream).await
292 }
293 }
294}
295
296async fn read_all(stream: &mut wtransport::RecvStream) -> Result<Vec<u8>, TransportError> {
297 let mut output = Vec::new();
298 let mut buffer = vec![0u8; 4096];
299
300 loop {
301 let bytes_read = stream
302 .read(&mut buffer)
303 .await
304 .map_err(|err| TransportError::StreamIo(err.to_string()))?;
305
306 match bytes_read {
307 Some(0) => break,
308 Some(bytes_read) => output.extend_from_slice(&buffer[..bytes_read]),
309 None => break,
310 }
311 }
312
313 Ok(output)
314}
315
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use tokio::time::{Duration, timeout};
    use wtransport::endpoint::endpoint_side::Client;
    use wtransport::{ClientConfig, Endpoint};

    use super::*;

    /// A connected server/client session pair on the loopback interface.
    ///
    /// The underscore-prefixed fields are never read. They are presumably held so the server
    /// and client endpoints are not dropped while the test is still using the session.
    struct LoopbackPair {
        _server: WebTransportServer,
        _client_endpoint: Endpoint<Client>,
        server_session: StreamSession,
        client_connection: Connection,
        path: String,
    }

    /// Sets up a connected server/client session pair over loopback.
    ///
    /// Steps:
    /// 1. Binds a self-signed server on an OS-assigned localhost port.
    /// 2. Connects a client that pins the server's certificate hash.
    /// 3. Returns both ends of the established session.
    ///
    /// Accept and connect run concurrently, each bounded by a 5-second timeout.
    async fn loopback_pair() -> Result<LoopbackPair> {
        // Port 0: let the OS pick a free port, then read it back via `local_addr`.
        let bind_address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let server =
            WebTransportServer::bind(TransportConfig::localhost_self_signed(bind_address))?;
        let cert_hash = server.certificate_hash().clone();
        let port = server.local_addr()?.port();

        // Trust exactly this server's certificate by hash.
        let client_config = ClientConfig::builder()
            .with_bind_default()
            .with_server_certificate_hashes([cert_hash])
            .build();

        let client_endpoint = Endpoint::client(client_config)?;
        let url = format!("https://127.0.0.1:{port}/stream");

        // Server accept and client connect are driven together because the session setup
        // needs both sides to make progress.
        let (accepted, client_connection) = tokio::join!(
            async {
                Ok::<_, anyhow::Error>(
                    timeout(Duration::from_secs(5), server.accept_session()).await??,
                )
            },
            async {
                Ok::<_, anyhow::Error>(
                    timeout(Duration::from_secs(5), client_endpoint.connect(url)).await??,
                )
            }
        );

        let accepted = accepted?;
        let client_connection = client_connection?;

        Ok(LoopbackPair {
            _server: server,
            _client_endpoint: client_endpoint,
            server_session: accepted.session,
            client_connection,
            path: accepted.path,
        })
    }

    /// Client-side counterpart of the module's `read_all`: reads a stream until it ends,
    /// with no size cap (test-only).
    async fn read_client_stream(stream: &mut wtransport::RecvStream) -> Result<Vec<u8>> {
        let mut output = Vec::new();
        let mut buffer = vec![0u8; 4096];

        loop {
            let bytes_read = stream.read(&mut buffer).await?;
            match bytes_read {
                Some(0) => break,
                Some(bytes_read) => output.extend_from_slice(&buffer[..bytes_read]),
                None => break,
            }
        }

        Ok(output)
    }

    /// Client -> server datagram carrying a pointer-move event round-trips intact, and the
    /// accepted session reports the requested path.
    #[tokio::test]
    async fn accepts_session_and_receives_input_datagram() -> Result<()> {
        let pair = loopback_pair().await?;
        assert_eq!(pair.path, "/stream");

        let input = InputEvent::PointerMove {
            x: 0.25,
            y: 0.75,
            buttons: 1,
            timestamp_ms: 4242,
        };

        pair.client_connection.send_datagram(&input.to_bytes())?;

        let received = timeout(
            Duration::from_secs(5),
            pair.server_session.recv_input_datagram(),
        )
        .await??;

        match received {
            InputEvent::PointerMove {
                x,
                y,
                buttons,
                timestamp_ms,
            } => {
                assert!((x - 0.25).abs() < f32::EPSILON);
                assert!((y - 0.75).abs() < f32::EPSILON);
                assert_eq!(buttons, 1);
                assert_eq!(timestamp_ms, 4242);
            }
            _ => panic!("expected pointer move"),
        }

        Ok(())
    }

    /// Server -> client frame packet sent as a datagram decodes with every field intact.
    #[tokio::test]
    async fn sends_frame_packets_over_datagrams() -> Result<()> {
        let pair = loopback_pair().await?;

        let packet = FramePacket {
            frame_id: 7,
            fragment_idx: 0,
            fragment_count: 1,
            timestamp_us: 123_456,
            is_keyframe: true,
            is_refine: false,
            is_lossless: false,
            payload: vec![1, 2, 3, 4, 5],
        };

        pair.server_session.send_frame_packet(&packet)?;

        let datagram = timeout(
            Duration::from_secs(5),
            pair.client_connection.receive_datagram(),
        )
        .await??;
        let decoded = FramePacket::from_bytes(datagram.as_ref())?;

        assert_eq!(decoded.frame_id, 7);
        assert_eq!(decoded.fragment_idx, 0);
        assert_eq!(decoded.fragment_count, 1);
        assert_eq!(decoded.timestamp_us, 123_456);
        assert!(decoded.is_keyframe);
        assert!(!decoded.is_refine);
        assert!(!decoded.is_lossless);
        assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]);

        Ok(())
    }

    /// Stream traffic in both directions.
    ///
    /// Checks two things:
    /// - a key event written by the client on a uni stream arrives via
    ///   `recv_reliable_input`;
    /// - a server control message arrives at the client on a uni stream.
    #[tokio::test]
    async fn receives_reliable_input_and_sends_control_message() -> Result<()> {
        let pair = loopback_pair().await?;

        let mut send_stream = pair.client_connection.open_uni().await?.await?;
        send_stream
            .write_all(&InputEvent::KeyDown { code: 0x0041 }.to_bytes())
            .await?;
        // Dropping the stream closes it, letting the server's read reach end-of-stream.
        drop(send_stream);

        let received = timeout(
            Duration::from_secs(5),
            pair.server_session.recv_reliable_input(),
        )
        .await??;

        match received {
            InputEvent::KeyDown { code } => assert_eq!(code, 0x0041),
            _ => panic!("expected key down"),
        }

        let control_message = b"codec=h265;mode=interactive";
        pair.server_session
            .send_control_message(control_message)
            .await?;

        let mut recv_stream =
            timeout(Duration::from_secs(5), pair.client_connection.accept_uni()).await??;
        let payload = read_client_stream(&mut recv_stream).await?;
        assert_eq!(payload, control_message);

        Ok(())
    }
}