1use futures::{future::Either, pin_mut};
2use std::{
3 collections::BTreeMap,
4 error::Error,
5 io::{Cursor, ErrorKind},
6 time::{Duration, Instant},
7};
8use tehuti::{
9 channel::ChannelMode,
10 engine::{
11 EngineId, EngineMeeting, EngineMeetingConfig, EngineMeetingEvent, EngineMeetingResult,
12 },
13 event::Duplex,
14 protocol::ProtocolFrame,
15};
16
/// Upper bound on the size of every datagram this transport sends
/// (fragment header plus fragment data). Presumably chosen to stay below the
/// path MTU — TODO confirm.
const QUIC_DATAGRAM_PAYLOAD_SIZE: usize = 1000;
/// Fragment header layout: `frame_id` (u32 LE) + `frag_idx` (u16 LE)
/// + `total_frags` (u16 LE) + one reserved byte.
const FRAGMENT_HEADER_SIZE: usize =
    std::mem::size_of::<u32>() + 2 * std::mem::size_of::<u16>() + 1;
/// Bytes of frame payload that fit into one datagram after the header.
const DATAGRAM_DATA_SIZE: usize = QUIC_DATAGRAM_PAYLOAD_SIZE - FRAGMENT_HEADER_SIZE;
20
21pub type WebTransportMeetingEvent = EngineMeetingEvent;
22pub type WebTransportMeetingResult = EngineMeetingResult;
23pub type WebTransportMeeting = EngineMeeting;
24pub type WebTransportMeetingConfig = EngineMeetingConfig;
25
/// Which side of the engine-ID handshake a session plays.
///
/// The initiator opens the bidirectional stream and sends its engine ID
/// first; the responder accepts the stream, reads the peer's ID, then
/// answers with its own.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WebTransportSessionRole {
    /// Opens the stream and speaks first.
    Initiator,
    /// Accepts the stream and answers.
    Responder,
}
31
32pub struct WebTransportSessionResult {
33 pub session: WebTransportSession,
34 pub frames: Duplex<ProtocolFrame>,
35}
36
37pub struct WebTransportSession {
38 session: web_transport::Session,
39 send_stream: web_transport::SendStream,
40 recv_stream: web_transport::RecvStream,
41 local_engine_id: EngineId,
42 remote_engine_id: EngineId,
43 buffer_in: Vec<u8>,
44 buffer_datagrams: Vec<Vec<u8>>,
45 frames: Duplex<ProtocolFrame>,
46 pub log_frames: bool,
47 next_frame_id: u32,
48 fragment_buffer: BTreeMap<u32, (Instant, BTreeMap<u16, Vec<u8>>)>,
49}
50
51impl WebTransportSession {
52 pub async fn make(
53 session: web_transport::Session,
54 local_engine_id: EngineId,
55 role: WebTransportSessionRole,
56 ) -> Result<WebTransportSessionResult, Box<dyn Error>> {
57 let (mut send_stream, mut recv_stream) = match role {
58 WebTransportSessionRole::Initiator => session.open_bi().await?,
59 WebTransportSessionRole::Responder => session.accept_bi().await?,
60 };
61
62 let remote_engine_id = match role {
63 WebTransportSessionRole::Initiator => {
64 Self::write_engine_id(&mut send_stream, local_engine_id).await?;
65 Self::read_engine_id(&mut recv_stream).await?
66 }
67 WebTransportSessionRole::Responder => {
68 let remote_engine_id = Self::read_engine_id(&mut recv_stream).await?;
69 Self::write_engine_id(&mut send_stream, local_engine_id).await?;
70 remote_engine_id
71 }
72 };
73
74 let (frames_inside, frames_outside) = Duplex::crossing_unbounded();
75 tracing::event!(
76 target: "tehuti::web_transport::session",
77 tracing::Level::TRACE,
78 "Session created. Local engine ID: {:?}, remote engine ID: {:?}",
79 local_engine_id,
80 remote_engine_id,
81 );
82
83 Ok(WebTransportSessionResult {
84 session: Self {
85 session,
86 send_stream,
87 recv_stream,
88 local_engine_id,
89 remote_engine_id,
90 buffer_in: Vec::new(),
91 buffer_datagrams: Vec::new(),
92 frames: frames_inside,
93 log_frames: false,
94 next_frame_id: 1,
95 fragment_buffer: BTreeMap::new(),
96 },
97 frames: frames_outside,
98 })
99 }
100
101 pub fn log_frames(mut self, value: bool) -> Self {
102 self.log_frames = value;
103 self
104 }
105
106 pub fn local_engine_id(&self) -> EngineId {
107 self.local_engine_id
108 }
109
110 pub fn remote_engine_id(&self) -> EngineId {
111 self.remote_engine_id
112 }
113
114 pub fn close(&self, code: u32, reason: &str) {
115 self.session.close(code, reason);
116 }
117
118 pub async fn maintain(&mut self) -> Result<(), Box<dyn Error>> {
119 self.send_frames().await?;
120 self.receive_frames().await?;
121 Ok(())
122 }
123
124 pub async fn into_future(mut self) -> Result<(), Box<dyn Error>> {
125 loop {
126 let result = async {
127 self.send_frames().await?;
128 self.receive_datagrams().await;
129
130 let inbound: Option<Vec<u8>> = {
131 let read = self.recv_stream.read(65536);
132 let timeout = futures_timer::Delay::new(Duration::from_millis(5));
133 pin_mut!(read);
134 pin_mut!(timeout);
135
136 match futures::future::select(read, timeout).await {
137 Either::Left((chunk, _)) => match chunk? {
138 Some(chunk) => Some(chunk.to_vec()),
139 None => return Err("WebTransport receive stream closed".into()),
140 },
141 Either::Right((_, _)) => None,
142 }
143 };
144
145 if let Some(chunk) = inbound {
146 self.receive_chunk(chunk.as_ref())?;
147 }
148
149 self.process_buffered_datagrams()?;
150
151 Ok(())
152 }
153 .await;
154
155 if let Err(err) = result {
156 tracing::event!(
157 target: "tehuti::web_transport::session",
158 tracing::Level::ERROR,
159 "Session {:?}<->{:?} terminated with error: {}",
160 self.local_engine_id,
161 self.remote_engine_id,
162 err,
163 );
164 return Err(err);
165 }
166 }
167 }
168
169 async fn receive_frames(&mut self) -> Result<(), Box<dyn Error>> {
170 let Some(chunk) = self.recv_stream.read(65536).await? else {
171 return Err("WebTransport receive stream closed".into());
172 };
173
174 self.receive_chunk(chunk.as_ref())
175 }
176
177 fn receive_chunk(&mut self, chunk: &[u8]) -> Result<(), Box<dyn Error>> {
178 if chunk.is_empty() {
179 return Ok(());
180 }
181
182 self.buffer_in.extend_from_slice(chunk);
183
184 if self.log_frames {
185 tracing::event!(
186 target: "tehuti::web_transport::session",
187 tracing::Level::TRACE,
188 "Session {:?}<->{:?} received {} bytes",
189 self.local_engine_id,
190 self.remote_engine_id,
191 chunk.len(),
192 );
193 }
194
195 loop {
196 if self.buffer_in.len() < 4 {
197 break;
198 }
199
200 let mut size_bytes = [0u8; 4];
201 size_bytes.copy_from_slice(&self.buffer_in[..4]);
202 let frame_size = u32::from_le_bytes(size_bytes) as usize;
203
204 if self.buffer_in.len() < 4 + frame_size {
205 break;
206 }
207
208 let payload = self.buffer_in[4..(4 + frame_size)].to_vec();
209 self.buffer_in.drain(..(4 + frame_size));
210
211 let mut cursor = Cursor::new(payload.as_slice());
212 let mut frame = ProtocolFrame::read(&mut cursor)?;
213
214 if let ProtocolFrame::Packet(frame) = &mut frame
215 && frame.data.sender.is_none()
216 {
217 frame.data.sender = Some(self.remote_engine_id);
218 }
219
220 self.frames.sender.send(frame).map_err(|err| {
221 format!(
222 "Session {:?}<->{:?} frame sender error: {}",
223 self.local_engine_id, self.remote_engine_id, err
224 )
225 })?;
226 }
227
228 Ok(())
229 }
230
231 async fn send_frame(&mut self, mut frame: ProtocolFrame) -> Result<(), Box<dyn Error>> {
232 if let ProtocolFrame::Packet(frame) = &mut frame {
233 frame.data.sender = Some(self.local_engine_id);
234 }
235
236 let mut payload = Vec::new();
237 frame.write(&mut payload)?;
238
239 if self.log_frames {
240 tracing::event!(
241 target: "tehuti::web_transport::session",
242 tracing::Level::TRACE,
243 "Session {:?}<->{:?} writing frame: {:?}",
244 self.local_engine_id,
245 self.remote_engine_id,
246 frame,
247 );
248 }
249
250 if let ProtocolFrame::Packet(pkt) = &frame
251 && pkt.channel_mode == ChannelMode::Unreliable
252 {
253 self.send_unreliable_frame(payload).await?;
254 return Ok(());
255 }
256 let payload_size = payload.len() as u32;
257 self.write_all(&payload_size.to_le_bytes()).await?;
258 self.write_all(&payload).await?;
259 Ok(())
260 }
261
262 async fn send_frames(&mut self) -> Result<(), Box<dyn Error>> {
263 let frames: Vec<_> = self.frames.receiver.iter().collect();
264
265 for frame in frames {
266 self.send_frame(frame).await?;
267 }
268
269 Ok(())
270 }
271
272 async fn send_unreliable_frame(&mut self, payload: Vec<u8>) -> Result<(), Box<dyn Error>> {
273 let frame_id = self.next_frame_id;
274 self.next_frame_id = self.next_frame_id.wrapping_add(1);
275
276 let total_fragments = payload.len().div_ceil(DATAGRAM_DATA_SIZE) as u16;
277
278 for (frag_idx, chunk) in payload.chunks(DATAGRAM_DATA_SIZE).enumerate() {
279 let mut datagram_data = Vec::with_capacity(FRAGMENT_HEADER_SIZE + chunk.len());
280
281 datagram_data.extend_from_slice(&frame_id.to_le_bytes());
282 datagram_data.extend_from_slice(&(frag_idx as u16).to_le_bytes());
283 datagram_data.extend_from_slice(&total_fragments.to_le_bytes());
284 datagram_data.push(0);
285
286 datagram_data.extend_from_slice(chunk);
287 if let Err(err) = self.session.send_datagram(datagram_data.into()).await {
288 tracing::event!(
289 target: "tehuti::web_transport::session",
290 tracing::Level::WARN,
291 "Session {:?}<->{:?} failed to send datagram fragment {}/{}: {}",
292 self.local_engine_id,
293 self.remote_engine_id,
294 frag_idx + 1,
295 total_fragments,
296 err,
297 );
298 }
299 }
300
301 Ok(())
302 }
303
304 async fn receive_datagrams(&mut self) {
305 let datagram = {
306 let recv = self.session.recv_datagram();
307 let timeout = futures_timer::Delay::new(Duration::from_millis(1));
308 pin_mut!(recv);
309 pin_mut!(timeout);
310
311 match futures::future::select(recv, timeout).await {
312 Either::Left((Ok(bytes), _)) => Some(bytes.to_vec()),
313 Either::Left((Err(_), _)) => None,
314 Either::Right((_, _)) => None,
315 }
316 };
317
318 if let Some(datagram) = datagram {
319 self.buffer_datagrams.push(datagram);
320 }
321 }
322
323 fn process_buffered_datagrams(&mut self) -> Result<(), Box<dyn Error>> {
324 let now = Instant::now();
325 self.fragment_buffer.retain(|_, (received_at, _)| {
326 now.duration_since(*received_at) < Duration::from_secs(5)
327 });
328
329 for datagram in self.buffer_datagrams.drain(..) {
330 if datagram.len() < FRAGMENT_HEADER_SIZE {
331 tracing::event!(
332 target: "tehuti::web_transport::session",
333 tracing::Level::WARN,
334 "Session {:?}<->{:?} received datagram smaller than header",
335 self.local_engine_id,
336 self.remote_engine_id,
337 );
338 continue;
339 }
340
341 let frame_id = u32::from_le_bytes([datagram[0], datagram[1], datagram[2], datagram[3]]);
342 let frag_idx = u16::from_le_bytes([datagram[4], datagram[5]]) as usize;
343 let total_frags = u16::from_le_bytes([datagram[6], datagram[7]]) as usize;
344
345 let data = datagram[FRAGMENT_HEADER_SIZE..].to_vec();
346 let (_, fragments) = self
347 .fragment_buffer
348 .entry(frame_id)
349 .or_insert_with(|| (Instant::now(), BTreeMap::new()));
350
351 fragments.insert(frag_idx as u16, data);
352
353 if fragments.len() == total_frags
354 && fragments.keys().max().copied() == Some((total_frags - 1) as u16)
355 {
356 let (_, fragments) = self.fragment_buffer.remove(&frame_id).unwrap();
357
358 let mut payload = Vec::new();
359 for i in 0..total_frags {
360 if let Some(data) = fragments.get(&(i as u16)) {
361 payload.extend_from_slice(data);
362 }
363 }
364 let mut cursor = Cursor::new(payload.as_slice());
365 let mut frame = ProtocolFrame::read(&mut cursor)?;
366
367 if let ProtocolFrame::Packet(frame) = &mut frame
368 && frame.data.sender.is_none()
369 {
370 frame.data.sender = Some(self.remote_engine_id);
371 }
372
373 self.frames.sender.send(frame).map_err(|err| {
374 format!(
375 "Session {:?}<->{:?} frame sender error: {}",
376 self.local_engine_id, self.remote_engine_id, err
377 )
378 })?;
379 }
380 }
381
382 Ok(())
383 }
384
385 async fn write_all(&mut self, mut buffer: &[u8]) -> Result<(), Box<dyn Error>> {
386 while !buffer.is_empty() {
387 let wrote = self.send_stream.write(buffer).await?;
388 if wrote == 0 {
389 return Err("WebTransport send stream produced zero-byte write".into());
390 }
391 buffer = &buffer[wrote..];
392 }
393 Ok(())
394 }
395
396 async fn write_engine_id(
397 send_stream: &mut web_transport::SendStream,
398 engine_id: EngineId,
399 ) -> Result<(), Box<dyn Error>> {
400 let mut data = engine_id.id().to_le_bytes().to_vec();
401 while !data.is_empty() {
402 let wrote = send_stream.write(data.as_slice()).await?;
403 if wrote == 0 {
404 return Err("WebTransport handshake write failed".into());
405 }
406 data.drain(..wrote);
407 }
408 Ok(())
409 }
410
411 async fn read_engine_id(
412 recv_stream: &mut web_transport::RecvStream,
413 ) -> Result<EngineId, Box<dyn Error>> {
414 let mut data = Vec::with_capacity(16);
415 while data.len() < 16 {
416 let Some(chunk) = recv_stream.read(16 - data.len()).await? else {
417 return Err("WebTransport handshake stream closed".into());
418 };
419 if chunk.is_empty() {
420 return Err("WebTransport handshake stream produced empty read".into());
421 }
422 data.extend_from_slice(chunk.as_ref());
423 }
424
425 if data.len() != 16 {
426 return Err(std::io::Error::new(
427 ErrorKind::InvalidData,
428 format!("Invalid EngineId handshake size: {}", data.len()),
429 )
430 .into());
431 }
432
433 let mut engine_id_bytes = [0u8; 16];
434 engine_id_bytes.copy_from_slice(&data);
435 Ok(EngineId::new(u128::from_le_bytes(engine_id_bytes)))
436 }
437}
438
439#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
440pub struct WebTransportHost {
441 server: web_transport::Server,
442 local_engine_id: EngineId,
443 pub log_frames: bool,
444}
445
446#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
447impl WebTransportHost {
448 pub fn new(server: web_transport::Server, local_engine_id: EngineId) -> Self {
449 tracing::event!(
450 target: "tehuti::web_transport::host",
451 tracing::Level::TRACE,
452 "WebTransportHost created. Local engine ID: {:?}",
453 local_engine_id,
454 );
455 Self {
456 server,
457 local_engine_id,
458 log_frames: false,
459 }
460 }
461
462 pub fn log_frames(mut self, value: bool) -> Self {
463 self.log_frames = value;
464 self
465 }
466
467 pub fn local_engine_id(&self) -> EngineId {
468 self.local_engine_id
469 }
470
471 pub async fn accept(&mut self) -> Result<Option<WebTransportSessionResult>, Box<dyn Error>> {
472 let Some(session) = self.server.accept().await? else {
473 return Ok(None);
474 };
475
476 let mut session_result = WebTransportSession::make(
477 session,
478 self.local_engine_id,
479 WebTransportSessionRole::Responder,
480 )
481 .await?;
482 session_result.session.log_frames = self.log_frames;
483
484 Ok(Some(session_result))
485 }
486}