Skip to main content

tehuti_web_transport/
lib.rs

1use futures::{future::Either, pin_mut};
2use std::{
3    collections::BTreeMap,
4    error::Error,
5    io::{Cursor, ErrorKind},
6    time::{Duration, Instant},
7};
8use tehuti::{
9    channel::ChannelMode,
10    engine::{
11        EngineId, EngineMeeting, EngineMeetingConfig, EngineMeetingEvent, EngineMeetingResult,
12    },
13    event::Duplex,
14    protocol::ProtocolFrame,
15};
16
/// Upper bound on the size of every datagram this module sends: each one carries a fragment
/// header plus at most `DATAGRAM_DATA_SIZE` payload bytes.
const QUIC_DATAGRAM_PAYLOAD_SIZE: usize = 1_000;
/// Fragment header layout: frame id (`u32`), fragment index (`u16`), fragment count (`u16`)
/// and one reserved byte.
const FRAGMENT_HEADER_SIZE: usize =
    std::mem::size_of::<u32>() + 2 * std::mem::size_of::<u16>() + std::mem::size_of::<u8>();
/// Payload bytes that fit into one datagram after the fragment header.
const DATAGRAM_DATA_SIZE: usize = QUIC_DATAGRAM_PAYLOAD_SIZE - FRAGMENT_HEADER_SIZE;
20
21pub type WebTransportMeetingEvent = EngineMeetingEvent;
22pub type WebTransportMeetingResult = EngineMeetingResult;
23pub type WebTransportMeeting = EngineMeeting;
24pub type WebTransportMeetingConfig = EngineMeetingConfig;
25
/// Which side of the per-session engine-id handshake this endpoint plays.
///
/// The initiator opens the bidirectional stream and sends its engine id before reading the
/// peer's; the responder accepts the stream and replies only after reading the peer's id.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WebTransportSessionRole {
    /// Opens the bidirectional stream and writes its engine id first.
    Initiator,
    /// Accepts the bidirectional stream and reads the peer's engine id first.
    Responder,
}
31
/// Output of `WebTransportSession::make`: the session itself plus the application-facing end of
/// its frame duplex.
pub struct WebTransportSessionResult {
    /// The session; it has to be driven (e.g. via `maintain` or `into_future`) for frames to move.
    pub session: WebTransportSession,
    /// Outer end of the `Duplex::crossing_unbounded` pair created in `make`. The session reads
    /// outgoing frames from, and pushes decoded incoming frames into, the other end.
    /// NOTE(review): assumes the two ends are crossed so frames sent here reach the session —
    /// confirm against tehuti's `Duplex`.
    pub frames: Duplex<ProtocolFrame>,
}
36
/// One WebTransport connection to a remote engine, carrying tehuti protocol frames.
///
/// Reliable frames travel over a single bidirectional stream as a little-endian `u32` length
/// prefix followed by the encoded frame; packets on an unreliable channel are split into
/// fragment datagrams and reassembled on receipt.
// NOTE(review): fields are dropped in declaration order, so `session` is dropped before its streams.
pub struct WebTransportSession {
    /// Underlying WebTransport session; used for datagrams and `close`.
    session: web_transport::Session,
    /// Sending half of the bidirectional stream opened/accepted in `make` (handshake, then reliable frames).
    send_stream: web_transport::SendStream,
    /// Receiving half of the same bidirectional stream.
    recv_stream: web_transport::RecvStream,
    /// Engine id announced to the peer during the handshake and stamped on outgoing packets.
    local_engine_id: EngineId,
    /// Engine id read from the peer; stamped on incoming packets that carry no sender.
    remote_engine_id: EngineId,
    /// Stream bytes received but not yet forming a complete length-prefixed frame.
    buffer_in: Vec<u8>,
    /// Raw datagrams received but not yet parsed into fragments.
    buffer_datagrams: Vec<Vec<u8>>,
    /// Inner end of the frame duplex: outgoing frames are read from its `receiver`, decoded
    /// incoming frames are pushed into its `sender`.
    frames: Duplex<ProtocolFrame>,
    /// When `true`, received byte counts and outgoing frames are logged at TRACE level.
    pub log_frames: bool,
    /// Id assigned to the next outgoing unreliable frame; wraps on overflow.
    next_frame_id: u32,
    /// Partially received unreliable frames keyed by frame id. Each value holds the time the
    /// entry was created and the fragments received so far, keyed by fragment index. Entries
    /// older than 5 seconds are discarded.
    fragment_buffer: BTreeMap<u32, (Instant, BTreeMap<u16, Vec<u8>>)>,
}
50
51impl WebTransportSession {
52    pub async fn make(
53        session: web_transport::Session,
54        local_engine_id: EngineId,
55        role: WebTransportSessionRole,
56    ) -> Result<WebTransportSessionResult, Box<dyn Error>> {
57        let (mut send_stream, mut recv_stream) = match role {
58            WebTransportSessionRole::Initiator => session.open_bi().await?,
59            WebTransportSessionRole::Responder => session.accept_bi().await?,
60        };
61
62        let remote_engine_id = match role {
63            WebTransportSessionRole::Initiator => {
64                Self::write_engine_id(&mut send_stream, local_engine_id).await?;
65                Self::read_engine_id(&mut recv_stream).await?
66            }
67            WebTransportSessionRole::Responder => {
68                let remote_engine_id = Self::read_engine_id(&mut recv_stream).await?;
69                Self::write_engine_id(&mut send_stream, local_engine_id).await?;
70                remote_engine_id
71            }
72        };
73
74        let (frames_inside, frames_outside) = Duplex::crossing_unbounded();
75        tracing::event!(
76            target: "tehuti::web_transport::session",
77            tracing::Level::TRACE,
78            "Session created. Local engine ID: {:?}, remote engine ID: {:?}",
79            local_engine_id,
80            remote_engine_id,
81        );
82
83        Ok(WebTransportSessionResult {
84            session: Self {
85                session,
86                send_stream,
87                recv_stream,
88                local_engine_id,
89                remote_engine_id,
90                buffer_in: Vec::new(),
91                buffer_datagrams: Vec::new(),
92                frames: frames_inside,
93                log_frames: false,
94                next_frame_id: 1,
95                fragment_buffer: BTreeMap::new(),
96            },
97            frames: frames_outside,
98        })
99    }
100
101    pub fn log_frames(mut self, value: bool) -> Self {
102        self.log_frames = value;
103        self
104    }
105
106    pub fn local_engine_id(&self) -> EngineId {
107        self.local_engine_id
108    }
109
110    pub fn remote_engine_id(&self) -> EngineId {
111        self.remote_engine_id
112    }
113
114    pub fn close(&self, code: u32, reason: &str) {
115        self.session.close(code, reason);
116    }
117
118    pub async fn maintain(&mut self) -> Result<(), Box<dyn Error>> {
119        self.send_frames().await?;
120        self.receive_frames().await?;
121        Ok(())
122    }
123
124    pub async fn into_future(mut self) -> Result<(), Box<dyn Error>> {
125        loop {
126            let result = async {
127                self.send_frames().await?;
128                self.receive_datagrams().await;
129
130                let inbound: Option<Vec<u8>> = {
131                    let read = self.recv_stream.read(65536);
132                    let timeout = futures_timer::Delay::new(Duration::from_millis(5));
133                    pin_mut!(read);
134                    pin_mut!(timeout);
135
136                    match futures::future::select(read, timeout).await {
137                        Either::Left((chunk, _)) => match chunk? {
138                            Some(chunk) => Some(chunk.to_vec()),
139                            None => return Err("WebTransport receive stream closed".into()),
140                        },
141                        Either::Right((_, _)) => None,
142                    }
143                };
144
145                if let Some(chunk) = inbound {
146                    self.receive_chunk(chunk.as_ref())?;
147                }
148
149                self.process_buffered_datagrams()?;
150
151                Ok(())
152            }
153            .await;
154
155            if let Err(err) = result {
156                tracing::event!(
157                    target: "tehuti::web_transport::session",
158                    tracing::Level::ERROR,
159                    "Session {:?}<->{:?} terminated with error: {}",
160                    self.local_engine_id,
161                    self.remote_engine_id,
162                    err,
163                );
164                return Err(err);
165            }
166        }
167    }
168
169    async fn receive_frames(&mut self) -> Result<(), Box<dyn Error>> {
170        let Some(chunk) = self.recv_stream.read(65536).await? else {
171            return Err("WebTransport receive stream closed".into());
172        };
173
174        self.receive_chunk(chunk.as_ref())
175    }
176
177    fn receive_chunk(&mut self, chunk: &[u8]) -> Result<(), Box<dyn Error>> {
178        if chunk.is_empty() {
179            return Ok(());
180        }
181
182        self.buffer_in.extend_from_slice(chunk);
183
184        if self.log_frames {
185            tracing::event!(
186                target: "tehuti::web_transport::session",
187                tracing::Level::TRACE,
188                "Session {:?}<->{:?} received {} bytes",
189                self.local_engine_id,
190                self.remote_engine_id,
191                chunk.len(),
192            );
193        }
194
195        loop {
196            if self.buffer_in.len() < 4 {
197                break;
198            }
199
200            let mut size_bytes = [0u8; 4];
201            size_bytes.copy_from_slice(&self.buffer_in[..4]);
202            let frame_size = u32::from_le_bytes(size_bytes) as usize;
203
204            if self.buffer_in.len() < 4 + frame_size {
205                break;
206            }
207
208            let payload = self.buffer_in[4..(4 + frame_size)].to_vec();
209            self.buffer_in.drain(..(4 + frame_size));
210
211            let mut cursor = Cursor::new(payload.as_slice());
212            let mut frame = ProtocolFrame::read(&mut cursor)?;
213
214            if let ProtocolFrame::Packet(frame) = &mut frame
215                && frame.data.sender.is_none()
216            {
217                frame.data.sender = Some(self.remote_engine_id);
218            }
219
220            self.frames.sender.send(frame).map_err(|err| {
221                format!(
222                    "Session {:?}<->{:?} frame sender error: {}",
223                    self.local_engine_id, self.remote_engine_id, err
224                )
225            })?;
226        }
227
228        Ok(())
229    }
230
231    async fn send_frame(&mut self, mut frame: ProtocolFrame) -> Result<(), Box<dyn Error>> {
232        if let ProtocolFrame::Packet(frame) = &mut frame {
233            frame.data.sender = Some(self.local_engine_id);
234        }
235
236        let mut payload = Vec::new();
237        frame.write(&mut payload)?;
238
239        if self.log_frames {
240            tracing::event!(
241                target: "tehuti::web_transport::session",
242                tracing::Level::TRACE,
243                "Session {:?}<->{:?} writing frame: {:?}",
244                self.local_engine_id,
245                self.remote_engine_id,
246                frame,
247            );
248        }
249
250        if let ProtocolFrame::Packet(pkt) = &frame
251            && pkt.channel_mode == ChannelMode::Unreliable
252        {
253            self.send_unreliable_frame(payload).await?;
254            return Ok(());
255        }
256        let payload_size = payload.len() as u32;
257        self.write_all(&payload_size.to_le_bytes()).await?;
258        self.write_all(&payload).await?;
259        Ok(())
260    }
261
262    async fn send_frames(&mut self) -> Result<(), Box<dyn Error>> {
263        let frames: Vec<_> = self.frames.receiver.iter().collect();
264
265        for frame in frames {
266            self.send_frame(frame).await?;
267        }
268
269        Ok(())
270    }
271
272    async fn send_unreliable_frame(&mut self, payload: Vec<u8>) -> Result<(), Box<dyn Error>> {
273        let frame_id = self.next_frame_id;
274        self.next_frame_id = self.next_frame_id.wrapping_add(1);
275
276        let total_fragments = payload.len().div_ceil(DATAGRAM_DATA_SIZE) as u16;
277
278        for (frag_idx, chunk) in payload.chunks(DATAGRAM_DATA_SIZE).enumerate() {
279            let mut datagram_data = Vec::with_capacity(FRAGMENT_HEADER_SIZE + chunk.len());
280
281            datagram_data.extend_from_slice(&frame_id.to_le_bytes());
282            datagram_data.extend_from_slice(&(frag_idx as u16).to_le_bytes());
283            datagram_data.extend_from_slice(&total_fragments.to_le_bytes());
284            datagram_data.push(0);
285
286            datagram_data.extend_from_slice(chunk);
287            if let Err(err) = self.session.send_datagram(datagram_data.into()).await {
288                tracing::event!(
289                    target: "tehuti::web_transport::session",
290                    tracing::Level::WARN,
291                    "Session {:?}<->{:?} failed to send datagram fragment {}/{}: {}",
292                    self.local_engine_id,
293                    self.remote_engine_id,
294                    frag_idx + 1,
295                    total_fragments,
296                    err,
297                );
298            }
299        }
300
301        Ok(())
302    }
303
304    async fn receive_datagrams(&mut self) {
305        let datagram = {
306            let recv = self.session.recv_datagram();
307            let timeout = futures_timer::Delay::new(Duration::from_millis(1));
308            pin_mut!(recv);
309            pin_mut!(timeout);
310
311            match futures::future::select(recv, timeout).await {
312                Either::Left((Ok(bytes), _)) => Some(bytes.to_vec()),
313                Either::Left((Err(_), _)) => None,
314                Either::Right((_, _)) => None,
315            }
316        };
317
318        if let Some(datagram) = datagram {
319            self.buffer_datagrams.push(datagram);
320        }
321    }
322
323    fn process_buffered_datagrams(&mut self) -> Result<(), Box<dyn Error>> {
324        let now = Instant::now();
325        self.fragment_buffer.retain(|_, (received_at, _)| {
326            now.duration_since(*received_at) < Duration::from_secs(5)
327        });
328
329        for datagram in self.buffer_datagrams.drain(..) {
330            if datagram.len() < FRAGMENT_HEADER_SIZE {
331                tracing::event!(
332                    target: "tehuti::web_transport::session",
333                    tracing::Level::WARN,
334                    "Session {:?}<->{:?} received datagram smaller than header",
335                    self.local_engine_id,
336                    self.remote_engine_id,
337                );
338                continue;
339            }
340
341            let frame_id = u32::from_le_bytes([datagram[0], datagram[1], datagram[2], datagram[3]]);
342            let frag_idx = u16::from_le_bytes([datagram[4], datagram[5]]) as usize;
343            let total_frags = u16::from_le_bytes([datagram[6], datagram[7]]) as usize;
344
345            let data = datagram[FRAGMENT_HEADER_SIZE..].to_vec();
346            let (_, fragments) = self
347                .fragment_buffer
348                .entry(frame_id)
349                .or_insert_with(|| (Instant::now(), BTreeMap::new()));
350
351            fragments.insert(frag_idx as u16, data);
352
353            if fragments.len() == total_frags
354                && fragments.keys().max().copied() == Some((total_frags - 1) as u16)
355            {
356                let (_, fragments) = self.fragment_buffer.remove(&frame_id).unwrap();
357
358                let mut payload = Vec::new();
359                for i in 0..total_frags {
360                    if let Some(data) = fragments.get(&(i as u16)) {
361                        payload.extend_from_slice(data);
362                    }
363                }
364                let mut cursor = Cursor::new(payload.as_slice());
365                let mut frame = ProtocolFrame::read(&mut cursor)?;
366
367                if let ProtocolFrame::Packet(frame) = &mut frame
368                    && frame.data.sender.is_none()
369                {
370                    frame.data.sender = Some(self.remote_engine_id);
371                }
372
373                self.frames.sender.send(frame).map_err(|err| {
374                    format!(
375                        "Session {:?}<->{:?} frame sender error: {}",
376                        self.local_engine_id, self.remote_engine_id, err
377                    )
378                })?;
379            }
380        }
381
382        Ok(())
383    }
384
385    async fn write_all(&mut self, mut buffer: &[u8]) -> Result<(), Box<dyn Error>> {
386        while !buffer.is_empty() {
387            let wrote = self.send_stream.write(buffer).await?;
388            if wrote == 0 {
389                return Err("WebTransport send stream produced zero-byte write".into());
390            }
391            buffer = &buffer[wrote..];
392        }
393        Ok(())
394    }
395
396    async fn write_engine_id(
397        send_stream: &mut web_transport::SendStream,
398        engine_id: EngineId,
399    ) -> Result<(), Box<dyn Error>> {
400        let mut data = engine_id.id().to_le_bytes().to_vec();
401        while !data.is_empty() {
402            let wrote = send_stream.write(data.as_slice()).await?;
403            if wrote == 0 {
404                return Err("WebTransport handshake write failed".into());
405            }
406            data.drain(..wrote);
407        }
408        Ok(())
409    }
410
411    async fn read_engine_id(
412        recv_stream: &mut web_transport::RecvStream,
413    ) -> Result<EngineId, Box<dyn Error>> {
414        let mut data = Vec::with_capacity(16);
415        while data.len() < 16 {
416            let Some(chunk) = recv_stream.read(16 - data.len()).await? else {
417                return Err("WebTransport handshake stream closed".into());
418            };
419            if chunk.is_empty() {
420                return Err("WebTransport handshake stream produced empty read".into());
421            }
422            data.extend_from_slice(chunk.as_ref());
423        }
424
425        if data.len() != 16 {
426            return Err(std::io::Error::new(
427                ErrorKind::InvalidData,
428                format!("Invalid EngineId handshake size: {}", data.len()),
429            )
430            .into());
431        }
432
433        let mut engine_id_bytes = [0u8; 16];
434        engine_id_bytes.copy_from_slice(&data);
435        Ok(EngineId::new(u128::from_le_bytes(engine_id_bytes)))
436    }
437}
438
439#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
440pub struct WebTransportHost {
441    server: web_transport::Server,
442    local_engine_id: EngineId,
443    pub log_frames: bool,
444}
445
446#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
447impl WebTransportHost {
448    pub fn new(server: web_transport::Server, local_engine_id: EngineId) -> Self {
449        tracing::event!(
450            target: "tehuti::web_transport::host",
451            tracing::Level::TRACE,
452            "WebTransportHost created. Local engine ID: {:?}",
453            local_engine_id,
454        );
455        Self {
456            server,
457            local_engine_id,
458            log_frames: false,
459        }
460    }
461
462    pub fn log_frames(mut self, value: bool) -> Self {
463        self.log_frames = value;
464        self
465    }
466
467    pub fn local_engine_id(&self) -> EngineId {
468        self.local_engine_id
469    }
470
471    pub async fn accept(&mut self) -> Result<Option<WebTransportSessionResult>, Box<dyn Error>> {
472        let Some(session) = self.server.accept().await? else {
473            return Ok(None);
474        };
475
476        let mut session_result = WebTransportSession::make(
477            session,
478            self.local_engine_id,
479            WebTransportSessionRole::Responder,
480        )
481        .await?;
482        session_result.session.log_frames = self.log_frames;
483
484        Ok(Some(session_result))
485    }
486}