webrtc_http/
util.rs

1use std::{
2  future::Future,
3  io,
4  pin::Pin,
5  result::Result,
6  sync::Arc,
7  task::{Context, Poll},
8};
9
10use bytes::Bytes;
11use tokio::{
12  io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream},
13  sync::oneshot::Sender,
14};
15use webrtc::data_channel::RTCDataChannel;
16
/// Default size, in bytes, for read buffers: 8 KiB.
pub const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
18
19pub struct RTCDataChannelStream {
20  poll_data_channel: PollRTCDataChannel,
21  receiver: ReadHalf<SimplexStream>,
22  shutdown: Option<Sender<()>>,
23}
24
25unsafe impl Send for RTCDataChannelStream {}
26unsafe impl Sync for RTCDataChannelStream {}
27
28impl RTCDataChannelStream {
29  pub fn new(
30    poll_data_channel: PollRTCDataChannel,
31    receiver: ReadHalf<SimplexStream>,
32    shutdown: Sender<()>,
33  ) -> Self {
34    Self {
35      poll_data_channel,
36      receiver,
37      shutdown: Some(shutdown),
38    }
39  }
40
41  pub fn shutdown(&mut self) -> Result<(), ()> {
42    match self.shutdown.take() {
43      Some(shutdown) => shutdown.send(()),
44      None => Err(()),
45    }
46  }
47}
48
49impl AsyncRead for RTCDataChannelStream {
50  fn poll_read(
51    self: Pin<&mut Self>,
52    cx: &mut Context<'_>,
53    buf: &mut ReadBuf<'_>,
54  ) -> Poll<io::Result<()>> {
55    Pin::new(&mut self.get_mut().receiver).poll_read(cx, buf)
56  }
57}
58
59impl AsyncWrite for RTCDataChannelStream {
60  fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
61    Pin::new(&mut self.get_mut().poll_data_channel).poll_write(cx, buf)
62  }
63
64  fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
65    Pin::new(&mut self.get_mut().poll_data_channel).poll_flush(cx)
66  }
67
68  fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
69    let mut_self = self.get_mut();
70    match Pin::new(&mut mut_self.poll_data_channel).poll_shutdown(cx) {
71      Poll::Pending => Poll::Pending,
72      Poll::Ready(result) => {
73        let _ = mut_self.shutdown();
74        Poll::Ready(result)
75      }
76    }
77  }
78}
79
80pub struct PollRTCDataChannel {
81  stream_id_bytes: [u8; 4],
82  data_channel: Arc<RTCDataChannel>,
83  write_fut: Option<Pin<Box<dyn Future<Output = Result<usize, webrtc::Error>> + Send>>>,
84}
85
86impl PollRTCDataChannel {
87  pub fn new(stream_id: u32, data_channel: Arc<RTCDataChannel>) -> Self {
88    Self {
89      stream_id_bytes: stream_id.to_be_bytes(),
90      data_channel,
91      write_fut: None,
92    }
93  }
94}
95
96impl AsyncWrite for PollRTCDataChannel {
97  fn poll_write(
98    mut self: Pin<&mut Self>,
99    cx: &mut Context<'_>,
100    buf: &[u8],
101  ) -> Poll<io::Result<usize>> {
102    if buf.is_empty() {
103      return Poll::Ready(Ok(0));
104    }
105
106    if let Some(fut) = self.write_fut.as_mut() {
107      match fut.as_mut().poll(cx) {
108        Poll::Pending => Poll::Pending,
109        Poll::Ready(Err(e)) => {
110          let data_channel: Arc<RTCDataChannel> = self.data_channel.clone();
111          let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
112          self.write_fut = Some(Box::pin(async move {
113            data_channel.send(&bytes).await.map(|len| len - 4)
114          }));
115          Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
116        }
117        Poll::Ready(Ok(_)) => {
118          let data_channel = self.data_channel.clone();
119          let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
120          self.write_fut = Some(Box::pin(async move {
121            data_channel.send(&bytes).await.map(|len| len - 4)
122          }));
123          Poll::Ready(Ok(buf.len()))
124        }
125      }
126    } else {
127      let data_channel = self.data_channel.clone();
128      let bytes: Bytes = Bytes::from_owner([&self.stream_id_bytes, buf].concat());
129      let fut = self.write_fut.insert(Box::pin(async move {
130        data_channel.send(&bytes).await.map(|len| len - 4)
131      }));
132
133      match fut.as_mut().poll(cx) {
134        Poll::Pending => Poll::Ready(Ok(buf.len())),
135        Poll::Ready(Err(e)) => {
136          self.write_fut = None;
137          Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
138        }
139        Poll::Ready(Ok(n)) => {
140          self.write_fut = None;
141          Poll::Ready(Ok(n))
142        }
143      }
144    }
145  }
146
147  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148    match self.write_fut.as_mut() {
149      Some(fut) => match fut.as_mut().poll(cx) {
150        Poll::Pending => Poll::Pending,
151        Poll::Ready(Err(e)) => {
152          self.write_fut = None;
153          Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
154        }
155        Poll::Ready(Ok(_)) => {
156          self.write_fut = None;
157          Poll::Ready(Ok(()))
158        }
159      },
160      None => Poll::Ready(Ok(())),
161    }
162  }
163
164  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165    match self.as_mut().poll_flush(cx) {
166      Poll::Pending => return Poll::Pending,
167      Poll::Ready(_) => Poll::Ready(Ok(())),
168    }
169  }
170}