webrtc_http/
client.rs

1use std::sync::Arc;
2
3use dashmap::DashMap;
4use tokio::{
5  io::{simplex, AsyncWriteExt, SimplexStream, WriteHalf},
6  sync::oneshot::channel,
7};
8use webrtc::data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel};
9
10use crate::util::{PollRTCDataChannel, RTCDataChannelStream, DEFAULT_READ_BUF_SIZE};
11
12#[derive(Clone)]
13pub struct RTCClient {
14  streams: Arc<DashMap<u32, WriteHalf<SimplexStream>>>,
15  data_channel: Arc<RTCDataChannel>,
16}
17
18unsafe impl Send for RTCClient {}
19unsafe impl Sync for RTCClient {}
20
21impl RTCClient {
22  pub fn new(data_channel: Arc<RTCDataChannel>) -> Self {
23    let streams: Arc<DashMap<u32, WriteHalf<SimplexStream>>> = Arc::new(DashMap::new());
24
25    let on_message_streams = streams.clone();
26
27    data_channel.on_message(Box::new(move |mut msg: DataChannelMessage| {
28      if msg.data.len() < 4 {
29        eprintln!("received message with less than 4 bytes");
30        return Box::pin(async move {});
31      }
32      let stream_id_bytes = msg.data.split_to(4);
33      let mut stream_id_byte_slice: [u8; 4] = [0u8; 4];
34      stream_id_byte_slice.clone_from_slice(stream_id_bytes.as_ref());
35      let stream_id = u32::from_be_bytes(stream_id_byte_slice);
36
37      let pinned_streams = on_message_streams.clone();
38
39      Box::pin(async move {
40        if let Some(mut stream_sender) = pinned_streams.get_mut(&stream_id) {
41          match stream_sender.value_mut().write_all(msg.data.as_ref()).await {
42            Ok(_) => {}
43            Err(e) => {
44              panic!("error writing to stream: {}", e);
45            }
46          }
47        } else {
48          eprintln!("received message for unknown stream");
49        }
50      })
51    }));
52
53    data_channel.on_close(Box::new(move || {
54      Box::pin(async move {
55        // TODO: do not allow any new connections after this
56      })
57    }));
58
59    Self {
60      streams,
61      data_channel,
62    }
63  }
64
65  pub async fn connect(&self) -> RTCDataChannelStream {
66    let data_channel = self.data_channel.clone();
67    let mut max_message_size = if let Some(weak_transport) = data_channel.transport().await {
68      if let Some(transport) = weak_transport.upgrade() {
69        transport.get_capabilities().max_message_size as usize
70      } else {
71        DEFAULT_READ_BUF_SIZE
72      }
73    } else {
74      DEFAULT_READ_BUF_SIZE
75    };
76    if max_message_size <= 4 {
77      max_message_size = DEFAULT_READ_BUF_SIZE - 4;
78    }
79
80    let (receiver, sender) = simplex(max_message_size);
81    let stream_id = rand::random::<u32>();
82    let (shutdown_sender, shutdown_receiver) = channel();
83
84    let remove_streams = self.streams.clone();
85    tokio::task::spawn(async move {
86      match shutdown_receiver.await {
87        Ok(_) => {
88          remove_streams.remove(&stream_id);
89        }
90        Err(e) => {
91          eprintln!("error receiving shutdown: {}", e);
92        }
93      }
94    });
95
96    let poll_data_channel = PollRTCDataChannel::new(stream_id, self.data_channel.clone());
97    let stream = RTCDataChannelStream::new(poll_data_channel, receiver, shutdown_sender);
98    self.streams.insert(stream_id, sender);
99
100    stream
101  }
102}