// webrtc_data/data_channel/mod.rs

1#[cfg(test)]
2mod data_channel_test;
3
4use std::borrow::Borrow;
5use std::future::Future;
6use std::net::Shutdown;
7use std::pin::Pin;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use bytes::{Buf, Bytes};
14use portable_atomic::AtomicUsize;
15use sctp::association::Association;
16use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier;
17use sctp::stream::*;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use util::marshal::*;
20
21use crate::error::{Error, Result};
22use crate::message::message_channel_ack::*;
23use crate::message::message_channel_open::*;
24use crate::message::*;
25
/// Size, in bytes, of the scratch buffer `DataChannel::server` uses to receive the
/// peer's initial DCEP message.
const RECEIVE_MTU: usize = 8 * 1024;
27
28/// Config is used to configure the data channel.
29#[derive(Eq, PartialEq, Default, Clone, Debug)]
30pub struct Config {
31    pub channel_type: ChannelType,
32    pub negotiated: bool,
33    pub priority: u16,
34    pub reliability_parameter: u32,
35    pub label: String,
36    pub protocol: String,
37    pub max_message_size: u32,
38}
39
40/// DataChannel represents a data channel
41#[derive(Debug, Clone)]
42pub struct DataChannel {
43    pub config: Config,
44    stream: Arc<Stream>,
45
46    // stats
47    messages_sent: Arc<AtomicUsize>,
48    messages_received: Arc<AtomicUsize>,
49    bytes_sent: Arc<AtomicUsize>,
50    bytes_received: Arc<AtomicUsize>,
51}
52
53impl DataChannel {
54    pub fn new(stream: Arc<Stream>, config: Config) -> Self {
55        Self {
56            config,
57            stream,
58
59            messages_sent: Arc::new(AtomicUsize::default()),
60            messages_received: Arc::new(AtomicUsize::default()),
61            bytes_sent: Arc::new(AtomicUsize::default()),
62            bytes_received: Arc::new(AtomicUsize::default()),
63        }
64    }
65
66    /// Dial opens a data channels over SCTP
67    pub async fn dial(
68        association: &Arc<Association>,
69        identifier: u16,
70        config: Config,
71    ) -> Result<Self> {
72        let stream = association
73            .open_stream(identifier, PayloadProtocolIdentifier::Binary)
74            .await?;
75
76        Self::client(stream, config).await
77    }
78
79    /// Accept is used to accept incoming data channels over SCTP
80    pub async fn accept<T>(
81        association: &Arc<Association>,
82        config: Config,
83        existing_channels: &[T],
84    ) -> Result<Self>
85    where
86        T: Borrow<Self>,
87    {
88        let stream = association
89            .accept_stream()
90            .await
91            .ok_or(Error::ErrStreamClosed)?;
92
93        for channel in existing_channels.iter().map(|ch| ch.borrow()) {
94            if channel.stream_identifier() == stream.stream_identifier() {
95                let ch = channel.to_owned();
96                ch.stream
97                    .set_default_payload_type(PayloadProtocolIdentifier::Binary);
98                return Ok(ch);
99            }
100        }
101
102        stream.set_default_payload_type(PayloadProtocolIdentifier::Binary);
103
104        Self::server(stream, config).await
105    }
106
107    /// Client opens a data channel over an SCTP stream
108    pub async fn client(stream: Arc<Stream>, config: Config) -> Result<Self> {
109        if !config.negotiated {
110            let msg = Message::DataChannelOpen(DataChannelOpen {
111                channel_type: config.channel_type,
112                priority: config.priority,
113                reliability_parameter: config.reliability_parameter,
114                label: config.label.bytes().collect(),
115                protocol: config.protocol.bytes().collect(),
116            })
117            .marshal()?;
118
119            stream
120                .write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
121                .await?;
122        }
123        Ok(DataChannel::new(stream, config))
124    }
125
126    /// Server accepts a data channel over an SCTP stream
127    pub async fn server(stream: Arc<Stream>, mut config: Config) -> Result<Self> {
128        let mut buf = vec![0u8; RECEIVE_MTU];
129
130        let (n, ppi) = stream.read_sctp(&mut buf).await?;
131
132        if ppi != PayloadProtocolIdentifier::Dcep {
133            return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
134        }
135
136        let mut read_buf = &buf[..n];
137        let msg = Message::unmarshal(&mut read_buf)?;
138
139        if let Message::DataChannelOpen(dco) = msg {
140            config.channel_type = dco.channel_type;
141            config.priority = dco.priority;
142            config.reliability_parameter = dco.reliability_parameter;
143            config.label = String::from_utf8(dco.label)?;
144            config.protocol = String::from_utf8(dco.protocol)?;
145        } else {
146            return Err(Error::InvalidMessageType(msg.message_type() as u8));
147        };
148
149        let data_channel = DataChannel::new(stream, config);
150
151        data_channel.write_data_channel_ack().await?;
152        data_channel.commit_reliability_params();
153
154        Ok(data_channel)
155    }
156
157    /// Read reads a packet of len(p) bytes as binary data.
158    ///
159    /// See [`sctp::stream::Stream::read_sctp`].
160    pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
161        self.read_data_channel(buf).await.map(|(n, _)| n)
162    }
163
164    /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and
165    /// `true` if the data read is a string.
166    ///
167    /// See [`sctp::stream::Stream::read_sctp`].
168    pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> {
169        loop {
170            //TODO: add handling of cancel read_data_channel
171            let (mut n, ppi) = match self.stream.read_sctp(buf).await {
172                Ok((0, PayloadProtocolIdentifier::Unknown)) => {
173                    // The incoming stream was reset or the reading half was shutdown
174                    return Ok((0, false));
175                }
176                Ok((n, ppi)) => (n, ppi),
177                Err(err) => {
178                    // Shutdown the stream and send the reset request to the remote.
179                    self.close().await?;
180                    return Err(err.into());
181                }
182            };
183
184            let mut is_string = false;
185            match ppi {
186                PayloadProtocolIdentifier::Dcep => {
187                    let mut data = &buf[..n];
188                    match self.handle_dcep(&mut data).await {
189                        Ok(()) => {}
190                        Err(err) => {
191                            log::error!("Failed to handle DCEP: {err:?}");
192                        }
193                    }
194                    continue;
195                }
196                PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => {
197                    is_string = true;
198                }
199                _ => {}
200            };
201
202            match ppi {
203                PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => {
204                    n = 0;
205                }
206                _ => {}
207            };
208
209            self.messages_received.fetch_add(1, Ordering::SeqCst);
210            self.bytes_received.fetch_add(n, Ordering::SeqCst);
211
212            return Ok((n, is_string));
213        }
214    }
215
216    /// MessagesSent returns the number of messages sent
217    pub fn messages_sent(&self) -> usize {
218        self.messages_sent.load(Ordering::SeqCst)
219    }
220
221    /// MessagesReceived returns the number of messages received
222    pub fn messages_received(&self) -> usize {
223        self.messages_received.load(Ordering::SeqCst)
224    }
225
226    /// BytesSent returns the number of bytes sent
227    pub fn bytes_sent(&self) -> usize {
228        self.bytes_sent.load(Ordering::SeqCst)
229    }
230
231    /// BytesReceived returns the number of bytes received
232    pub fn bytes_received(&self) -> usize {
233        self.bytes_received.load(Ordering::SeqCst)
234    }
235
236    /// StreamIdentifier returns the Stream identifier associated to the stream.
237    pub fn stream_identifier(&self) -> u16 {
238        self.stream.stream_identifier()
239    }
240
241    async fn handle_dcep<B>(&self, data: &mut B) -> Result<()>
242    where
243        B: Buf,
244    {
245        let msg = Message::unmarshal(data)?;
246
247        match msg {
248            Message::DataChannelOpen(_) => {
249                // Note: DATA_CHANNEL_OPEN message is handled inside Server() method.
250                // Therefore, the message will not reach here.
251                log::debug!("Received DATA_CHANNEL_OPEN");
252                let _ = self.write_data_channel_ack().await?;
253            }
254            Message::DataChannelAck(_) => {
255                log::debug!("Received DATA_CHANNEL_ACK");
256                self.commit_reliability_params();
257            }
258        };
259
260        Ok(())
261    }
262
263    /// Write writes len(p) bytes from p as binary data
264    pub async fn write(&self, data: &Bytes) -> Result<usize> {
265        self.write_data_channel(data, false).await
266    }
267
268    /// WriteDataChannel writes len(p) bytes from p
269    pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize> {
270        let data_len = data.len();
271
272        // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
273        // SCTP does not support the sending of empty user messages.  Therefore,
274        // if an empty message has to be sent, the appropriate PPID (WebRTC
275        // String Empty or WebRTC Binary Empty) is used and the SCTP user
276        // message of one zero byte is sent.  When receiving an SCTP user
277        // message with one of these PPIDs, the receiver MUST ignore the SCTP
278        // user message and process it as an empty message.
279        let ppi = match (is_string, data_len) {
280            (false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
281            (false, _) => PayloadProtocolIdentifier::Binary,
282            (true, 0) => PayloadProtocolIdentifier::StringEmpty,
283            (true, _) => PayloadProtocolIdentifier::String,
284        };
285
286        let n = if data_len == 0 {
287            let _ = self
288                .stream
289                .write_sctp(&Bytes::from_static(&[0]), ppi)
290                .await?;
291            0
292        } else {
293            let n = self.stream.write_sctp(data, ppi).await?;
294            self.bytes_sent.fetch_add(n, Ordering::SeqCst);
295            n
296        };
297
298        self.messages_sent.fetch_add(1, Ordering::SeqCst);
299        Ok(n)
300    }
301
302    async fn write_data_channel_ack(&self) -> Result<usize> {
303        let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
304        Ok(self
305            .stream
306            .write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
307            .await?)
308    }
309
310    /// Close closes the DataChannel and the underlying SCTP stream.
311    pub async fn close(&self) -> Result<()> {
312        // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
313        // Closing of a data channel MUST be signaled by resetting the
314        // corresponding outgoing streams [RFC6525].  This means that if one
315        // side decides to close the data channel, it resets the corresponding
316        // outgoing stream.  When the peer sees that an incoming stream was
317        // reset, it also resets its corresponding outgoing stream.  Once this
318        // is completed, the data channel is closed.  Resetting a stream sets
319        // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
320        // a corresponding notification to the application layer that the reset
321        // has been performed.  Streams are available for reuse after a reset
322        // has been performed.
323        Ok(self.stream.shutdown(Shutdown::Both).await?)
324    }
325
326    /// BufferedAmount returns the number of bytes of data currently queued to be
327    /// sent over this stream.
328    pub fn buffered_amount(&self) -> usize {
329        self.stream.buffered_amount()
330    }
331
332    /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
333    /// data that is considered "low." Defaults to 0.
334    pub fn buffered_amount_low_threshold(&self) -> usize {
335        self.stream.buffered_amount_low_threshold()
336    }
337
338    /// SetBufferedAmountLowThreshold is used to update the threshold.
339    /// See BufferedAmountLowThreshold().
340    pub fn set_buffered_amount_low_threshold(&self, threshold: usize) {
341        self.stream.set_buffered_amount_low_threshold(threshold)
342    }
343
344    /// OnBufferedAmountLow sets the callback handler which would be called when the
345    /// number of bytes of outgoing data buffered is lower than the threshold.
346    pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
347        self.stream.on_buffered_amount_low(f)
348    }
349
350    fn commit_reliability_params(&self) {
351        let (unordered, reliability_type) = match self.config.channel_type {
352            ChannelType::Reliable => (false, ReliabilityType::Reliable),
353            ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
354            ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
355            ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
356            ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
357            ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
358        };
359
360        self.stream.set_reliability_params(
361            unordered,
362            reliability_type,
363            self.config.reliability_parameter,
364        );
365    }
366}
367
/// Capacity, in bytes, of the temporary buffer each `PollDataChannel` read goes
/// through, unless overridden with `set_read_buf_capacity`.
const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
370
371/// State of the read `Future` in [`PollStream`].
372enum ReadFut {
373    /// Nothing in progress.
374    Idle,
375    /// Reading data from the underlying stream.
376    Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
377    /// Finished reading, but there's unread data in the temporary buffer.
378    RemainingData(Vec<u8>),
379}
380
381impl ReadFut {
382    /// Gets a mutable reference to the future stored inside `Reading(future)`.
383    ///
384    /// # Panics
385    ///
386    /// Panics if `ReadFut` variant is not `Reading`.
387    fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
388        match self {
389            ReadFut::Reading(ref mut fut) => fut,
390            _ => panic!("expected ReadFut to be Reading"),
391        }
392    }
393}
394
395/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
396/// [`AsyncWrite`].
397///
398/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
399/// additional overhead.
400pub struct PollDataChannel {
401    data_channel: Arc<DataChannel>,
402
403    read_fut: ReadFut,
404    write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
405    shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,
406
407    read_buf_cap: usize,
408}
409
410impl PollDataChannel {
411    /// Constructs a new `PollDataChannel`.
412    pub fn new(data_channel: Arc<DataChannel>) -> Self {
413        Self {
414            data_channel,
415            read_fut: ReadFut::Idle,
416            write_fut: None,
417            shutdown_fut: None,
418            read_buf_cap: DEFAULT_READ_BUF_SIZE,
419        }
420    }
421
422    /// Get back the inner data_channel.
423    pub fn into_inner(self) -> Arc<DataChannel> {
424        self.data_channel
425    }
426
427    /// Obtain a clone of the inner data_channel.
428    pub fn clone_inner(&self) -> Arc<DataChannel> {
429        self.data_channel.clone()
430    }
431
432    /// MessagesSent returns the number of messages sent
433    pub fn messages_sent(&self) -> usize {
434        self.data_channel.messages_sent()
435    }
436
437    /// MessagesReceived returns the number of messages received
438    pub fn messages_received(&self) -> usize {
439        self.data_channel.messages_received()
440    }
441
442    /// BytesSent returns the number of bytes sent
443    pub fn bytes_sent(&self) -> usize {
444        self.data_channel.bytes_sent()
445    }
446
447    /// BytesReceived returns the number of bytes received
448    pub fn bytes_received(&self) -> usize {
449        self.data_channel.bytes_received()
450    }
451
452    /// StreamIdentifier returns the Stream identifier associated to the stream.
453    pub fn stream_identifier(&self) -> u16 {
454        self.data_channel.stream_identifier()
455    }
456
457    /// BufferedAmount returns the number of bytes of data currently queued to be
458    /// sent over this stream.
459    pub fn buffered_amount(&self) -> usize {
460        self.data_channel.buffered_amount()
461    }
462
463    /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
464    /// data that is considered "low." Defaults to 0.
465    pub fn buffered_amount_low_threshold(&self) -> usize {
466        self.data_channel.buffered_amount_low_threshold()
467    }
468
469    /// Set the capacity of the temporary read buffer (default: 8192).
470    pub fn set_read_buf_capacity(&mut self, capacity: usize) {
471        self.read_buf_cap = capacity
472    }
473}
474
475impl AsyncRead for PollDataChannel {
476    fn poll_read(
477        mut self: Pin<&mut Self>,
478        cx: &mut Context<'_>,
479        buf: &mut ReadBuf<'_>,
480    ) -> Poll<io::Result<()>> {
481        if buf.remaining() == 0 {
482            return Poll::Ready(Ok(()));
483        }
484
485        let fut = match self.read_fut {
486            ReadFut::Idle => {
487                // read into a temporary buffer because `buf` has an unonymous lifetime, which can
488                // be shorter than the lifetime of `read_fut`.
489                let data_channel = self.data_channel.clone();
490                let mut temp_buf = vec![0; self.read_buf_cap];
491                self.read_fut = ReadFut::Reading(Box::pin(async move {
492                    data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
493                        temp_buf.truncate(n);
494                        temp_buf
495                    })
496                }));
497                self.read_fut.get_reading_mut()
498            }
499            ReadFut::Reading(ref mut fut) => fut,
500            ReadFut::RemainingData(ref mut data) => {
501                let remaining = buf.remaining();
502                let len = std::cmp::min(data.len(), remaining);
503                buf.put_slice(&data[..len]);
504                if data.len() > remaining {
505                    // ReadFut remains to be RemainingData
506                    data.drain(..len);
507                } else {
508                    self.read_fut = ReadFut::Idle;
509                }
510                return Poll::Ready(Ok(()));
511            }
512        };
513
514        loop {
515            match fut.as_mut().poll(cx) {
516                Poll::Pending => return Poll::Pending,
517                // retry immediately upon empty data or incomplete chunks
518                // since there's no way to setup a waker.
519                Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
520                // EOF has been reached => don't touch buf and just return Ok
521                Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
522                    self.read_fut = ReadFut::Idle;
523                    return Poll::Ready(Ok(()));
524                }
525                Poll::Ready(Err(e)) => {
526                    self.read_fut = ReadFut::Idle;
527                    return Poll::Ready(Err(e.into()));
528                }
529                Poll::Ready(Ok(mut temp_buf)) => {
530                    let remaining = buf.remaining();
531                    let len = std::cmp::min(temp_buf.len(), remaining);
532                    buf.put_slice(&temp_buf[..len]);
533                    if temp_buf.len() > remaining {
534                        temp_buf.drain(..len);
535                        self.read_fut = ReadFut::RemainingData(temp_buf);
536                    } else {
537                        self.read_fut = ReadFut::Idle;
538                    }
539                    return Poll::Ready(Ok(()));
540                }
541            }
542        }
543    }
544}
545
546impl AsyncWrite for PollDataChannel {
547    fn poll_write(
548        mut self: Pin<&mut Self>,
549        cx: &mut Context<'_>,
550        buf: &[u8],
551    ) -> Poll<io::Result<usize>> {
552        if buf.is_empty() {
553            return Poll::Ready(Ok(0));
554        }
555
556        if let Some(fut) = self.write_fut.as_mut() {
557            match fut.as_mut().poll(cx) {
558                Poll::Pending => Poll::Pending,
559                Poll::Ready(Err(e)) => {
560                    let data_channel = self.data_channel.clone();
561                    let bytes = Bytes::copy_from_slice(buf);
562                    self.write_fut =
563                        Some(Box::pin(async move { data_channel.write(&bytes).await }));
564                    Poll::Ready(Err(e.into()))
565                }
566                // Given the data is buffered, it's okay to ignore the number of written bytes.
567                //
568                // TODO: In the long term, `data_channel.write` should be made sync. Then we could
569                // remove the whole `if` condition and just call `data_channel.write`.
570                Poll::Ready(Ok(_)) => {
571                    let data_channel = self.data_channel.clone();
572                    let bytes = Bytes::copy_from_slice(buf);
573                    self.write_fut =
574                        Some(Box::pin(async move { data_channel.write(&bytes).await }));
575                    Poll::Ready(Ok(buf.len()))
576                }
577            }
578        } else {
579            let data_channel = self.data_channel.clone();
580            let bytes = Bytes::copy_from_slice(buf);
581            let fut = self
582                .write_fut
583                .insert(Box::pin(async move { data_channel.write(&bytes).await }));
584
585            match fut.as_mut().poll(cx) {
586                // If it's the first time we're polling the future, `Poll::Pending` can't be
587                // returned because that would mean the `PollDataChannel` is not ready for writing.
588                // And this is not true since we've just created a future, which is going to write
589                // the buf to the underlying stream.
590                //
591                // It's okay to return `Poll::Ready` if the data is buffered (this is what the
592                // buffered writer and `File` do).
593                Poll::Pending => Poll::Ready(Ok(buf.len())),
594                Poll::Ready(Err(e)) => {
595                    self.write_fut = None;
596                    Poll::Ready(Err(e.into()))
597                }
598                Poll::Ready(Ok(n)) => {
599                    self.write_fut = None;
600                    Poll::Ready(Ok(n))
601                }
602            }
603        }
604    }
605
606    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
607        match self.write_fut.as_mut() {
608            Some(fut) => match fut.as_mut().poll(cx) {
609                Poll::Pending => Poll::Pending,
610                Poll::Ready(Err(e)) => {
611                    self.write_fut = None;
612                    Poll::Ready(Err(e.into()))
613                }
614                Poll::Ready(Ok(_)) => {
615                    self.write_fut = None;
616                    Poll::Ready(Ok(()))
617                }
618            },
619            None => Poll::Ready(Ok(())),
620        }
621    }
622
623    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
624        match self.as_mut().poll_flush(cx) {
625            Poll::Pending => return Poll::Pending,
626            Poll::Ready(_) => {}
627        }
628
629        let fut = match self.shutdown_fut.as_mut() {
630            Some(fut) => fut,
631            None => {
632                let data_channel = self.data_channel.clone();
633                self.shutdown_fut.get_or_insert(Box::pin(async move {
634                    data_channel
635                        .stream
636                        .shutdown(Shutdown::Write)
637                        .await
638                        .map_err(Error::Sctp)
639                }))
640            }
641        };
642
643        match fut.as_mut().poll(cx) {
644            Poll::Pending => Poll::Pending,
645            Poll::Ready(Err(e)) => {
646                self.shutdown_fut = None;
647                Poll::Ready(Err(e.into()))
648            }
649            Poll::Ready(Ok(_)) => {
650                self.shutdown_fut = None;
651                Poll::Ready(Ok(()))
652            }
653        }
654    }
655}
656
657impl Clone for PollDataChannel {
658    fn clone(&self) -> PollDataChannel {
659        PollDataChannel::new(self.clone_inner())
660    }
661}
662
663impl fmt::Debug for PollDataChannel {
664    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665        f.debug_struct("PollDataChannel")
666            .field("data_channel", &self.data_channel)
667            .field("read_buf_cap", &self.read_buf_cap)
668            .finish()
669    }
670}
671
672impl AsRef<DataChannel> for PollDataChannel {
673    fn as_ref(&self) -> &DataChannel {
674        &self.data_channel
675    }
676}