// webrtc_http/server.rs

1use std::{io, sync::Arc};
2
3use dashmap::DashMap;
4use tokio::{
5  io::{simplex, AsyncWriteExt},
6  runtime::Handle,
7  sync::{
8    mpsc::{self, Receiver},
9    oneshot::channel,
10    Mutex,
11  },
12  task::block_in_place,
13};
14use webrtc::data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel};
15
16use crate::util::{PollRTCDataChannel, RTCDataChannelStream, DEFAULT_READ_BUF_SIZE};
17
18#[derive(Clone)]
19pub struct RTCListener {
20  receiver: Arc<Mutex<Receiver<RTCDataChannelStream>>>,
21}
22
23unsafe impl Send for RTCListener {}
24unsafe impl Sync for RTCListener {}
25
26impl RTCListener {
27  pub fn new(data_channel: Arc<RTCDataChannel>) -> Self {
28    let (stream_sender, stream_receiver) = mpsc::channel(1024);
29    let stream_receiver = Arc::new(Mutex::new(stream_receiver));
30    let streams = Arc::new(DashMap::new());
31
32    let on_message_streams = streams.clone();
33    let on_message_data_channel = data_channel.clone();
34
35    data_channel.on_message(Box::new(move |mut msg: DataChannelMessage| {
36      if msg.data.len() < 4 {
37        eprintln!("received message with less than 4 bytes");
38        return Box::pin(async move {});
39      }
40      let stream_id_bytes = msg.data.split_to(4);
41      let mut stream_id_byte_slice: [u8; 4] = [0u8; 4];
42      stream_id_byte_slice.clone_from_slice(stream_id_bytes.as_ref());
43      let stream_id = u32::from_be_bytes(stream_id_byte_slice);
44
45      let pinned_streams = on_message_streams.clone();
46      let pinned_stream_sender = stream_sender.clone();
47      let pinned_data_channel = on_message_data_channel.clone();
48
49      Box::pin(async move {
50        let remove_streams = pinned_streams.clone();
51
52        let mut stream_sender = pinned_streams.entry(stream_id).or_insert_with(move || {
53          let transport_data_channel = pinned_data_channel.clone();
54          let mut max_message_size = if let Some(weak_transport) =
55            block_in_place(move || Handle::current().block_on(transport_data_channel.transport()))
56          {
57            if let Some(transport) = weak_transport.upgrade() {
58              transport.get_capabilities().max_message_size as usize
59            } else {
60              DEFAULT_READ_BUF_SIZE
61            }
62          } else {
63            DEFAULT_READ_BUF_SIZE
64          };
65          if max_message_size <= 4 {
66            max_message_size = DEFAULT_READ_BUF_SIZE - 4;
67          }
68
69          // TODO get max message size of transports
70          let (receiver, sender) = simplex(max_message_size);
71          let (shutdown_sender, shutdown_receiver) = channel();
72
73          tokio::task::spawn(async move {
74            match shutdown_receiver.await {
75              Ok(_) => {
76                remove_streams.remove(&stream_id);
77              }
78              Err(e) => {
79                eprintln!("error receiving shutdown: {}", e);
80              }
81            }
82          });
83
84          match block_in_place(move || {
85            let poll_rtc_data_channel = PollRTCDataChannel::new(stream_id, pinned_data_channel);
86            let rtc_data_channel_stream =
87              RTCDataChannelStream::new(poll_rtc_data_channel, receiver, shutdown_sender);
88            Handle::current().block_on(pinned_stream_sender.send(rtc_data_channel_stream))
89          }) {
90            Ok(_) => {}
91            Err(e) => {
92              panic!("error sending stream to accept: {}", e);
93            }
94          }
95
96          sender
97        });
98
99        match stream_sender.value_mut().write_all(msg.data.as_ref()).await {
100          Ok(_) => {}
101          Err(e) => {
102            panic!("error writing to stream: {}", e);
103          }
104        }
105      })
106    }));
107
108    let on_close_stream_sender = stream_receiver.clone();
109
110    data_channel.on_close(Box::new(move || {
111      let pinned_stream_sender = on_close_stream_sender.clone();
112
113      Box::pin(async move {
114        pinned_stream_sender.lock().await.close();
115      })
116    }));
117
118    Self {
119      receiver: stream_receiver,
120    }
121  }
122
123  pub async fn accept(&self) -> io::Result<RTCDataChannelStream> {
124    match self.receiver.lock().await.recv().await {
125      Some(stream) => Ok(stream),
126      None => Err(io::Error::new(
127        io::ErrorKind::Other,
128        "accept called on closed RTCListener",
129      )),
130    }
131  }
132
133  pub async fn close(&self) {
134    self.receiver.lock().await.close();
135  }
136}
137
#[cfg(feature = "axum")]
impl axum::serve::Listener for RTCListener {
  type Addr = ();
  type Io = RTCDataChannelStream;

  /// Yields the next accepted stream, retrying forever on failure.
  ///
  /// This trait method has no error channel, so each failure is logged and followed by a
  /// one-second pause before the next attempt.
  async fn accept(&mut self) -> (Self::Io, Self::Addr) {
    loop {
      // Explicitly call the inherent `RTCListener::accept`, not this trait method.
      let failure = match RTCListener::accept(&*self).await {
        Ok(stream) => break (stream, ()),
        Err(failure) => failure,
      };
      eprintln!("error accepting stream: {}", failure);
      tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    }
  }

  /// Data-channel streams have no socket address, so this always reports `()`.
  fn local_addr(&self) -> io::Result<Self::Addr> {
    Ok(())
  }
}
158}