1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 result::Result,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use bytes::Bytes;
11use tokio::{
12 io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream},
13 sync::oneshot::Sender,
14};
15use webrtc::data_channel::RTCDataChannel;
16
/// Default read-buffer size (8 KiB).
///
/// NOTE(review): not referenced in this excerpt; presumably used as the capacity of the
/// `SimplexStream` feeding `RTCDataChannelStream::receiver` — confirm at the call site.
pub const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
18
19pub struct RTCDataChannelStream {
20 poll_data_channel: PollRTCDataChannel,
21 receiver: ReadHalf<SimplexStream>,
22 shutdown: Option<Sender<()>>,
23}
24
// SAFETY: NOTE(review): if every field is already `Send` (the boxed write future inside
// `PollRTCDataChannel` is declared `+ Send`), this manual impl is redundant with the auto
// impl — confirm and consider removing it so the compiler checks it instead.
unsafe impl Send for RTCDataChannelStream {}
// SAFETY: NOTE(review): the boxed `dyn Future + Send` write future is not `Sync`, so this
// impl is what makes the type `Sync`. It appears sound only because, in the code visible
// here, that future is reached exclusively through `Pin<&mut Self>` (the `AsyncRead` /
// `AsyncWrite` methods) and never through `&self`. Any future `&self` accessor that touches
// `poll_data_channel.write_fut` would make this unsound — verify no such access exists.
unsafe impl Sync for RTCDataChannelStream {}
27
28impl RTCDataChannelStream {
29 pub fn new(
30 poll_data_channel: PollRTCDataChannel,
31 receiver: ReadHalf<SimplexStream>,
32 shutdown: Sender<()>,
33 ) -> Self {
34 Self {
35 poll_data_channel,
36 receiver,
37 shutdown: Some(shutdown),
38 }
39 }
40
41 pub fn shutdown(&mut self) -> Result<(), ()> {
42 match self.shutdown.take() {
43 Some(shutdown) => shutdown.send(()),
44 None => Err(()),
45 }
46 }
47}
48
49impl AsyncRead for RTCDataChannelStream {
50 fn poll_read(
51 self: Pin<&mut Self>,
52 cx: &mut Context<'_>,
53 buf: &mut ReadBuf<'_>,
54 ) -> Poll<io::Result<()>> {
55 Pin::new(&mut self.get_mut().receiver).poll_read(cx, buf)
56 }
57}
58
59impl AsyncWrite for RTCDataChannelStream {
60 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
61 Pin::new(&mut self.get_mut().poll_data_channel).poll_write(cx, buf)
62 }
63
64 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
65 Pin::new(&mut self.get_mut().poll_data_channel).poll_flush(cx)
66 }
67
68 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
69 let mut_self = self.get_mut();
70 match Pin::new(&mut mut_self.poll_data_channel).poll_shutdown(cx) {
71 Poll::Pending => Poll::Pending,
72 Poll::Ready(result) => {
73 let _ = mut_self.shutdown();
74 Poll::Ready(result)
75 }
76 }
77 }
78}
79
80pub struct PollRTCDataChannel {
81 stream_id_bytes: [u8; 4],
82 data_channel: Arc<RTCDataChannel>,
83 write_fut: Option<Pin<Box<dyn Future<Output = Result<usize, webrtc::Error>> + Send>>>,
84}
85
86impl PollRTCDataChannel {
87 pub fn new(stream_id: u32, data_channel: Arc<RTCDataChannel>) -> Self {
88 Self {
89 stream_id_bytes: stream_id.to_be_bytes(),
90 data_channel,
91 write_fut: None,
92 }
93 }
94}
95
96impl AsyncWrite for PollRTCDataChannel {
97 fn poll_write(
98 mut self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 buf: &[u8],
101 ) -> Poll<io::Result<usize>> {
102 if buf.is_empty() {
103 return Poll::Ready(Ok(0));
104 }
105
106 if let Some(fut) = self.write_fut.as_mut() {
107 match fut.as_mut().poll(cx) {
108 Poll::Pending => Poll::Pending,
109 Poll::Ready(Err(e)) => {
110 let data_channel: Arc<RTCDataChannel> = self.data_channel.clone();
111 let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
112 self.write_fut = Some(Box::pin(async move {
113 data_channel.send(&bytes).await.map(|len| len - 4)
114 }));
115 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
116 }
117 Poll::Ready(Ok(_)) => {
118 let data_channel = self.data_channel.clone();
119 let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
120 self.write_fut = Some(Box::pin(async move {
121 data_channel.send(&bytes).await.map(|len| len - 4)
122 }));
123 Poll::Ready(Ok(buf.len()))
124 }
125 }
126 } else {
127 let data_channel = self.data_channel.clone();
128 let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
129 let fut = self.write_fut.insert(Box::pin(async move {
130 data_channel.send(&bytes).await.map(|len| len - 4)
131 }));
132
133 match fut.as_mut().poll(cx) {
134 Poll::Pending => Poll::Ready(Ok(buf.len())),
135 Poll::Ready(Err(e)) => {
136 self.write_fut = None;
137 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
138 }
139 Poll::Ready(Ok(n)) => {
140 self.write_fut = None;
141 Poll::Ready(Ok(n))
142 }
143 }
144 }
145 }
146
147 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148 match self.write_fut.as_mut() {
149 Some(fut) => match fut.as_mut().poll(cx) {
150 Poll::Pending => Poll::Pending,
151 Poll::Ready(Err(e)) => {
152 self.write_fut = None;
153 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
154 }
155 Poll::Ready(Ok(_)) => {
156 self.write_fut = None;
157 Poll::Ready(Ok(()))
158 }
159 },
160 None => Poll::Ready(Ok(())),
161 }
162 }
163
164 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165 match self.as_mut().poll_flush(cx) {
166 Poll::Pending => return Poll::Pending,
167 Poll::Ready(_) => Poll::Ready(Ok(())),
168 }
169 }
170}