1use std::{io, sync::Arc};
2
3use dashmap::DashMap;
4use tokio::{
5 io::{simplex, AsyncWriteExt},
6 runtime::Handle,
7 sync::{
8 mpsc::{self, Receiver},
9 oneshot::channel,
10 Mutex,
11 },
12 task::block_in_place,
13};
14use webrtc::data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel};
15
16use crate::util::{PollRTCDataChannel, RTCDataChannelStream, DEFAULT_READ_BUF_SIZE};
17
18#[derive(Clone)]
19pub struct RTCListener {
20 receiver: Arc<Mutex<Receiver<RTCDataChannelStream>>>,
21}
22
23unsafe impl Send for RTCListener {}
24unsafe impl Sync for RTCListener {}
25
26impl RTCListener {
27 pub fn new(data_channel: Arc<RTCDataChannel>) -> Self {
28 let (stream_sender, stream_receiver) = mpsc::channel(1024);
29 let stream_receiver = Arc::new(Mutex::new(stream_receiver));
30 let streams = Arc::new(DashMap::new());
31
32 let on_message_streams = streams.clone();
33 let on_message_data_channel = data_channel.clone();
34
35 data_channel.on_message(Box::new(move |mut msg: DataChannelMessage| {
36 if msg.data.len() < 4 {
37 eprintln!("received message with less than 4 bytes");
38 return Box::pin(async move {});
39 }
40 let stream_id_bytes = msg.data.split_to(4);
41 let mut stream_id_byte_slice: [u8; 4] = [0u8; 4];
42 stream_id_byte_slice.clone_from_slice(stream_id_bytes.as_ref());
43 let stream_id = u32::from_be_bytes(stream_id_byte_slice);
44
45 let pinned_streams = on_message_streams.clone();
46 let pinned_stream_sender = stream_sender.clone();
47 let pinned_data_channel = on_message_data_channel.clone();
48
49 Box::pin(async move {
50 let remove_streams = pinned_streams.clone();
51
52 let mut stream_sender = pinned_streams.entry(stream_id).or_insert_with(move || {
53 let transport_data_channel = pinned_data_channel.clone();
54 let mut max_message_size = if let Some(weak_transport) =
55 block_in_place(move || Handle::current().block_on(transport_data_channel.transport()))
56 {
57 if let Some(transport) = weak_transport.upgrade() {
58 transport.get_capabilities().max_message_size as usize
59 } else {
60 DEFAULT_READ_BUF_SIZE
61 }
62 } else {
63 DEFAULT_READ_BUF_SIZE
64 };
65 if max_message_size <= 4 {
66 max_message_size = DEFAULT_READ_BUF_SIZE - 4;
67 }
68
69 let (receiver, sender) = simplex(max_message_size);
71 let (shutdown_sender, shutdown_receiver) = channel();
72
73 tokio::task::spawn(async move {
74 match shutdown_receiver.await {
75 Ok(_) => {
76 remove_streams.remove(&stream_id);
77 }
78 Err(e) => {
79 eprintln!("error receiving shutdown: {}", e);
80 }
81 }
82 });
83
84 match block_in_place(move || {
85 let poll_rtc_data_channel = PollRTCDataChannel::new(stream_id, pinned_data_channel);
86 let rtc_data_channel_stream =
87 RTCDataChannelStream::new(poll_rtc_data_channel, receiver, shutdown_sender);
88 Handle::current().block_on(pinned_stream_sender.send(rtc_data_channel_stream))
89 }) {
90 Ok(_) => {}
91 Err(e) => {
92 panic!("error sending stream to accept: {}", e);
93 }
94 }
95
96 sender
97 });
98
99 match stream_sender.value_mut().write_all(msg.data.as_ref()).await {
100 Ok(_) => {}
101 Err(e) => {
102 panic!("error writing to stream: {}", e);
103 }
104 }
105 })
106 }));
107
108 let on_close_stream_sender = stream_receiver.clone();
109
110 data_channel.on_close(Box::new(move || {
111 let pinned_stream_sender = on_close_stream_sender.clone();
112
113 Box::pin(async move {
114 pinned_stream_sender.lock().await.close();
115 })
116 }));
117
118 Self {
119 receiver: stream_receiver,
120 }
121 }
122
123 pub async fn accept(&self) -> io::Result<RTCDataChannelStream> {
124 match self.receiver.lock().await.recv().await {
125 Some(stream) => Ok(stream),
126 None => Err(io::Error::new(
127 io::ErrorKind::Other,
128 "accept called on closed RTCListener",
129 )),
130 }
131 }
132
133 pub async fn close(&self) {
134 self.receiver.lock().await.close();
135 }
136}
137
#[cfg(feature = "axum")]
impl axum::serve::Listener for RTCListener {
    type Io = RTCDataChannelStream;
    type Addr = ();

    /// Yields the next accepted stream to axum.
    ///
    /// axum's `Listener::accept` cannot report errors, so every failure of the
    /// inherent `RTCListener::accept` is logged and retried after one second.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        let listener: &RTCListener = self;
        loop {
            let failure = match RTCListener::accept(listener).await {
                Ok(stream) => break (stream, ()),
                Err(failure) => failure,
            };
            eprintln!("error accepting stream: {}", failure);
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        }
    }

    /// Data-channel streams have no meaningful socket address.
    fn local_addr(&self) -> io::Result<Self::Addr> {
        Ok(())
    }
}