1use std::sync::Arc;
2
3use dashmap::DashMap;
4use tokio::{
5 io::{simplex, AsyncWriteExt, SimplexStream, WriteHalf},
6 sync::oneshot::channel,
7};
8use webrtc::data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel};
9
10use crate::util::{PollRTCDataChannel, RTCDataChannelStream, DEFAULT_READ_BUF_SIZE};
11
12#[derive(Clone)]
13pub struct RTCClient {
14 streams: Arc<DashMap<u32, WriteHalf<SimplexStream>>>,
15 data_channel: Arc<RTCDataChannel>,
16}
17
18unsafe impl Send for RTCClient {}
19unsafe impl Sync for RTCClient {}
20
21impl RTCClient {
22 pub fn new(data_channel: Arc<RTCDataChannel>) -> Self {
23 let streams: Arc<DashMap<u32, WriteHalf<SimplexStream>>> = Arc::new(DashMap::new());
24
25 let on_message_streams = streams.clone();
26
27 data_channel.on_message(Box::new(move |mut msg: DataChannelMessage| {
28 if msg.data.len() < 4 {
29 eprintln!("received message with less than 4 bytes");
30 return Box::pin(async move {});
31 }
32 let stream_id_bytes = msg.data.split_to(4);
33 let mut stream_id_byte_slice: [u8; 4] = [0u8; 4];
34 stream_id_byte_slice.clone_from_slice(stream_id_bytes.as_ref());
35 let stream_id = u32::from_be_bytes(stream_id_byte_slice);
36
37 let pinned_streams = on_message_streams.clone();
38
39 Box::pin(async move {
40 if let Some(mut stream_sender) = pinned_streams.get_mut(&stream_id) {
41 match stream_sender.value_mut().write_all(msg.data.as_ref()).await {
42 Ok(_) => {}
43 Err(e) => {
44 panic!("error writing to stream: {}", e);
45 }
46 }
47 } else {
48 eprintln!("received message for unknown stream");
49 }
50 })
51 }));
52
53 data_channel.on_close(Box::new(move || {
54 Box::pin(async move {
55 })
57 }));
58
59 Self {
60 streams,
61 data_channel,
62 }
63 }
64
65 pub async fn connect(&self) -> RTCDataChannelStream {
66 let data_channel = self.data_channel.clone();
67 let mut max_message_size = if let Some(weak_transport) = data_channel.transport().await {
68 if let Some(transport) = weak_transport.upgrade() {
69 transport.get_capabilities().max_message_size as usize
70 } else {
71 DEFAULT_READ_BUF_SIZE
72 }
73 } else {
74 DEFAULT_READ_BUF_SIZE
75 };
76 if max_message_size <= 4 {
77 max_message_size = DEFAULT_READ_BUF_SIZE - 4;
78 }
79
80 let (receiver, sender) = simplex(max_message_size);
81 let stream_id = rand::random::<u32>();
82 let (shutdown_sender, shutdown_receiver) = channel();
83
84 let remove_streams = self.streams.clone();
85 tokio::task::spawn(async move {
86 match shutdown_receiver.await {
87 Ok(_) => {
88 remove_streams.remove(&stream_id);
89 }
90 Err(e) => {
91 eprintln!("error receiving shutdown: {}", e);
92 }
93 }
94 });
95
96 let poll_data_channel = PollRTCDataChannel::new(stream_id, self.data_channel.clone());
97 let stream = RTCDataChannelStream::new(poll_data_channel, receiver, shutdown_sender);
98 self.streams.insert(stream_id, sender);
99
100 stream
101 }
102}