webrtc_data/data_channel/
mod.rs1#[cfg(test)]
2mod data_channel_test;
3
4use std::borrow::Borrow;
5use std::future::Future;
6use std::net::Shutdown;
7use std::pin::Pin;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use bytes::{Buf, Bytes};
14use portable_atomic::AtomicUsize;
15use sctp::association::Association;
16use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier;
17use sctp::stream::*;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use util::marshal::*;
20
21use crate::error::{Error, Result};
22use crate::message::message_channel_ack::*;
23use crate::message::message_channel_open::*;
24use crate::message::*;
25
/// Size in bytes of the buffer `DataChannel::server` reads the peer's initial DCEP message into.
const RECEIVE_MTU: usize = 8 * 1024;
27
/// Parameters describing a data channel.
///
/// On the client side they are sent to the peer in a DATA_CHANNEL_OPEN message (unless
/// `negotiated` is set); on the server side `DataChannel::server` overwrites the channel type,
/// priority, reliability parameter, label and protocol with the values the peer sent.
#[derive(Eq, PartialEq, Default, Clone, Debug)]
pub struct Config {
    /// Ordering/reliability mode; mapped to stream reliability settings when the channel's
    /// reliability parameters are committed.
    pub channel_type: ChannelType,
    /// When `true`, `DataChannel::client` does not send a DATA_CHANNEL_OPEN message
    /// (the channel is assumed to be negotiated out of band).
    pub negotiated: bool,
    /// Priority carried in DATA_CHANNEL_OPEN.
    pub priority: u16,
    /// Passed to the stream together with the reliability type derived from `channel_type`.
    /// NOTE(review): presumably a retransmit count or a lifetime depending on the type —
    /// confirm against the DCEP specification.
    pub reliability_parameter: u32,
    /// Channel label; sent as raw bytes in DATA_CHANNEL_OPEN and decoded as UTF-8 on receipt.
    pub label: String,
    /// Sub-protocol name; sent as raw bytes in DATA_CHANNEL_OPEN and decoded as UTF-8 on receipt.
    pub protocol: String,
    /// NOTE(review): not read by any code in this section of the file.
    pub max_message_size: u32,
}
39
/// A WebRTC data channel running over a single SCTP stream.
///
/// Clones share the same stream (`Arc<Stream>`) and the same message/byte counters;
/// `config` is copied.
#[derive(Debug, Clone)]
pub struct DataChannel {
    /// Channel parameters; filled in from the peer's DATA_CHANNEL_OPEN on the server side.
    pub config: Config,
    stream: Arc<Stream>,

    // Counters are `Arc`-shared, so every clone of a channel observes and updates the same
    // values.
    messages_sent: Arc<AtomicUsize>,
    messages_received: Arc<AtomicUsize>,
    bytes_sent: Arc<AtomicUsize>,
    bytes_received: Arc<AtomicUsize>,
}
52
53impl DataChannel {
54 pub fn new(stream: Arc<Stream>, config: Config) -> Self {
55 Self {
56 config,
57 stream,
58
59 messages_sent: Arc::new(AtomicUsize::default()),
60 messages_received: Arc::new(AtomicUsize::default()),
61 bytes_sent: Arc::new(AtomicUsize::default()),
62 bytes_received: Arc::new(AtomicUsize::default()),
63 }
64 }
65
66 pub async fn dial(
68 association: &Arc<Association>,
69 identifier: u16,
70 config: Config,
71 ) -> Result<Self> {
72 let stream = association
73 .open_stream(identifier, PayloadProtocolIdentifier::Binary)
74 .await?;
75
76 Self::client(stream, config).await
77 }
78
79 pub async fn accept<T>(
81 association: &Arc<Association>,
82 config: Config,
83 existing_channels: &[T],
84 ) -> Result<Self>
85 where
86 T: Borrow<Self>,
87 {
88 let stream = association
89 .accept_stream()
90 .await
91 .ok_or(Error::ErrStreamClosed)?;
92
93 for channel in existing_channels.iter().map(|ch| ch.borrow()) {
94 if channel.stream_identifier() == stream.stream_identifier() {
95 let ch = channel.to_owned();
96 ch.stream
97 .set_default_payload_type(PayloadProtocolIdentifier::Binary);
98 return Ok(ch);
99 }
100 }
101
102 stream.set_default_payload_type(PayloadProtocolIdentifier::Binary);
103
104 Self::server(stream, config).await
105 }
106
107 pub async fn client(stream: Arc<Stream>, config: Config) -> Result<Self> {
109 if !config.negotiated {
110 let msg = Message::DataChannelOpen(DataChannelOpen {
111 channel_type: config.channel_type,
112 priority: config.priority,
113 reliability_parameter: config.reliability_parameter,
114 label: config.label.bytes().collect(),
115 protocol: config.protocol.bytes().collect(),
116 })
117 .marshal()?;
118
119 stream
120 .write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
121 .await?;
122 }
123 Ok(DataChannel::new(stream, config))
124 }
125
126 pub async fn server(stream: Arc<Stream>, mut config: Config) -> Result<Self> {
128 let mut buf = vec![0u8; RECEIVE_MTU];
129
130 let (n, ppi) = stream.read_sctp(&mut buf).await?;
131
132 if ppi != PayloadProtocolIdentifier::Dcep {
133 return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
134 }
135
136 let mut read_buf = &buf[..n];
137 let msg = Message::unmarshal(&mut read_buf)?;
138
139 if let Message::DataChannelOpen(dco) = msg {
140 config.channel_type = dco.channel_type;
141 config.priority = dco.priority;
142 config.reliability_parameter = dco.reliability_parameter;
143 config.label = String::from_utf8(dco.label)?;
144 config.protocol = String::from_utf8(dco.protocol)?;
145 } else {
146 return Err(Error::InvalidMessageType(msg.message_type() as u8));
147 };
148
149 let data_channel = DataChannel::new(stream, config);
150
151 data_channel.write_data_channel_ack().await?;
152 data_channel.commit_reliability_params();
153
154 Ok(data_channel)
155 }
156
157 pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
161 self.read_data_channel(buf).await.map(|(n, _)| n)
162 }
163
164 pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> {
169 loop {
170 let (mut n, ppi) = match self.stream.read_sctp(buf).await {
172 Ok((0, PayloadProtocolIdentifier::Unknown)) => {
173 return Ok((0, false));
175 }
176 Ok((n, ppi)) => (n, ppi),
177 Err(err) => {
178 self.close().await?;
180 return Err(err.into());
181 }
182 };
183
184 let mut is_string = false;
185 match ppi {
186 PayloadProtocolIdentifier::Dcep => {
187 let mut data = &buf[..n];
188 match self.handle_dcep(&mut data).await {
189 Ok(()) => {}
190 Err(err) => {
191 log::error!("Failed to handle DCEP: {err:?}");
192 }
193 }
194 continue;
195 }
196 PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => {
197 is_string = true;
198 }
199 _ => {}
200 };
201
202 match ppi {
203 PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => {
204 n = 0;
205 }
206 _ => {}
207 };
208
209 self.messages_received.fetch_add(1, Ordering::SeqCst);
210 self.bytes_received.fetch_add(n, Ordering::SeqCst);
211
212 return Ok((n, is_string));
213 }
214 }
215
216 pub fn messages_sent(&self) -> usize {
218 self.messages_sent.load(Ordering::SeqCst)
219 }
220
221 pub fn messages_received(&self) -> usize {
223 self.messages_received.load(Ordering::SeqCst)
224 }
225
226 pub fn bytes_sent(&self) -> usize {
228 self.bytes_sent.load(Ordering::SeqCst)
229 }
230
231 pub fn bytes_received(&self) -> usize {
233 self.bytes_received.load(Ordering::SeqCst)
234 }
235
236 pub fn stream_identifier(&self) -> u16 {
238 self.stream.stream_identifier()
239 }
240
241 async fn handle_dcep<B>(&self, data: &mut B) -> Result<()>
242 where
243 B: Buf,
244 {
245 let msg = Message::unmarshal(data)?;
246
247 match msg {
248 Message::DataChannelOpen(_) => {
249 log::debug!("Received DATA_CHANNEL_OPEN");
252 let _ = self.write_data_channel_ack().await?;
253 }
254 Message::DataChannelAck(_) => {
255 log::debug!("Received DATA_CHANNEL_ACK");
256 self.commit_reliability_params();
257 }
258 };
259
260 Ok(())
261 }
262
263 pub async fn write(&self, data: &Bytes) -> Result<usize> {
265 self.write_data_channel(data, false).await
266 }
267
268 pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize> {
270 let data_len = data.len();
271
272 let ppi = match (is_string, data_len) {
280 (false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
281 (false, _) => PayloadProtocolIdentifier::Binary,
282 (true, 0) => PayloadProtocolIdentifier::StringEmpty,
283 (true, _) => PayloadProtocolIdentifier::String,
284 };
285
286 let n = if data_len == 0 {
287 let _ = self
288 .stream
289 .write_sctp(&Bytes::from_static(&[0]), ppi)
290 .await?;
291 0
292 } else {
293 let n = self.stream.write_sctp(data, ppi).await?;
294 self.bytes_sent.fetch_add(n, Ordering::SeqCst);
295 n
296 };
297
298 self.messages_sent.fetch_add(1, Ordering::SeqCst);
299 Ok(n)
300 }
301
302 async fn write_data_channel_ack(&self) -> Result<usize> {
303 let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
304 Ok(self
305 .stream
306 .write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
307 .await?)
308 }
309
310 pub async fn close(&self) -> Result<()> {
312 Ok(self.stream.shutdown(Shutdown::Both).await?)
324 }
325
326 pub fn buffered_amount(&self) -> usize {
329 self.stream.buffered_amount()
330 }
331
332 pub fn buffered_amount_low_threshold(&self) -> usize {
335 self.stream.buffered_amount_low_threshold()
336 }
337
338 pub fn set_buffered_amount_low_threshold(&self, threshold: usize) {
341 self.stream.set_buffered_amount_low_threshold(threshold)
342 }
343
344 pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
347 self.stream.on_buffered_amount_low(f)
348 }
349
350 fn commit_reliability_params(&self) {
351 let (unordered, reliability_type) = match self.config.channel_type {
352 ChannelType::Reliable => (false, ReliabilityType::Reliable),
353 ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
354 ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
355 ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
356 ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
357 ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
358 };
359
360 self.stream.set_reliability_params(
361 unordered,
362 reliability_type,
363 self.config.reliability_parameter,
364 );
365 }
366}
367
/// Default capacity, in bytes, of the temporary buffer `PollDataChannel` reads each message
/// into; adjustable with `PollDataChannel::set_read_buf_capacity`.
const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
370
/// Read-side state of [`PollDataChannel`].
enum ReadFut {
    /// No read in flight and no leftover data.
    Idle,
    /// A read of the next message is in flight; resolves to the received bytes.
    Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
    /// Bytes of an already-received message that did not fit into the caller's buffer yet.
    RemainingData(Vec<u8>),
}
380
381impl ReadFut {
382 fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
388 match self {
389 ReadFut::Reading(ref mut fut) => fut,
390 _ => panic!("expected ReadFut to be Reading"),
391 }
392 }
393}
394
/// Adapter exposing a [`DataChannel`] through tokio's `AsyncRead` and `AsyncWrite` traits.
///
/// Holds at most one in-flight read, write and shutdown future at a time.
pub struct PollDataChannel {
    data_channel: Arc<DataChannel>,

    /// Read-side state: idle, reading a message, or holding leftover bytes.
    read_fut: ReadFut,
    /// Write of the most recently accepted buffer, if it has not completed yet.
    write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
    /// `Shutdown::Write` of the underlying stream, if it has not completed yet.
    shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,

    /// Capacity of the temporary buffer each message is read into.
    read_buf_cap: usize,
}
409
410impl PollDataChannel {
411 pub fn new(data_channel: Arc<DataChannel>) -> Self {
413 Self {
414 data_channel,
415 read_fut: ReadFut::Idle,
416 write_fut: None,
417 shutdown_fut: None,
418 read_buf_cap: DEFAULT_READ_BUF_SIZE,
419 }
420 }
421
422 pub fn into_inner(self) -> Arc<DataChannel> {
424 self.data_channel
425 }
426
427 pub fn clone_inner(&self) -> Arc<DataChannel> {
429 self.data_channel.clone()
430 }
431
432 pub fn messages_sent(&self) -> usize {
434 self.data_channel.messages_sent()
435 }
436
437 pub fn messages_received(&self) -> usize {
439 self.data_channel.messages_received()
440 }
441
442 pub fn bytes_sent(&self) -> usize {
444 self.data_channel.bytes_sent()
445 }
446
447 pub fn bytes_received(&self) -> usize {
449 self.data_channel.bytes_received()
450 }
451
452 pub fn stream_identifier(&self) -> u16 {
454 self.data_channel.stream_identifier()
455 }
456
457 pub fn buffered_amount(&self) -> usize {
460 self.data_channel.buffered_amount()
461 }
462
463 pub fn buffered_amount_low_threshold(&self) -> usize {
466 self.data_channel.buffered_amount_low_threshold()
467 }
468
469 pub fn set_read_buf_capacity(&mut self, capacity: usize) {
471 self.read_buf_cap = capacity
472 }
473}
474
475impl AsyncRead for PollDataChannel {
476 fn poll_read(
477 mut self: Pin<&mut Self>,
478 cx: &mut Context<'_>,
479 buf: &mut ReadBuf<'_>,
480 ) -> Poll<io::Result<()>> {
481 if buf.remaining() == 0 {
482 return Poll::Ready(Ok(()));
483 }
484
485 let fut = match self.read_fut {
486 ReadFut::Idle => {
487 let data_channel = self.data_channel.clone();
490 let mut temp_buf = vec![0; self.read_buf_cap];
491 self.read_fut = ReadFut::Reading(Box::pin(async move {
492 data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
493 temp_buf.truncate(n);
494 temp_buf
495 })
496 }));
497 self.read_fut.get_reading_mut()
498 }
499 ReadFut::Reading(ref mut fut) => fut,
500 ReadFut::RemainingData(ref mut data) => {
501 let remaining = buf.remaining();
502 let len = std::cmp::min(data.len(), remaining);
503 buf.put_slice(&data[..len]);
504 if data.len() > remaining {
505 data.drain(..len);
507 } else {
508 self.read_fut = ReadFut::Idle;
509 }
510 return Poll::Ready(Ok(()));
511 }
512 };
513
514 loop {
515 match fut.as_mut().poll(cx) {
516 Poll::Pending => return Poll::Pending,
517 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
520 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
522 self.read_fut = ReadFut::Idle;
523 return Poll::Ready(Ok(()));
524 }
525 Poll::Ready(Err(e)) => {
526 self.read_fut = ReadFut::Idle;
527 return Poll::Ready(Err(e.into()));
528 }
529 Poll::Ready(Ok(mut temp_buf)) => {
530 let remaining = buf.remaining();
531 let len = std::cmp::min(temp_buf.len(), remaining);
532 buf.put_slice(&temp_buf[..len]);
533 if temp_buf.len() > remaining {
534 temp_buf.drain(..len);
535 self.read_fut = ReadFut::RemainingData(temp_buf);
536 } else {
537 self.read_fut = ReadFut::Idle;
538 }
539 return Poll::Ready(Ok(()));
540 }
541 }
542 }
543 }
544}
545
546impl AsyncWrite for PollDataChannel {
547 fn poll_write(
548 mut self: Pin<&mut Self>,
549 cx: &mut Context<'_>,
550 buf: &[u8],
551 ) -> Poll<io::Result<usize>> {
552 if buf.is_empty() {
553 return Poll::Ready(Ok(0));
554 }
555
556 if let Some(fut) = self.write_fut.as_mut() {
557 match fut.as_mut().poll(cx) {
558 Poll::Pending => Poll::Pending,
559 Poll::Ready(Err(e)) => {
560 let data_channel = self.data_channel.clone();
561 let bytes = Bytes::copy_from_slice(buf);
562 self.write_fut =
563 Some(Box::pin(async move { data_channel.write(&bytes).await }));
564 Poll::Ready(Err(e.into()))
565 }
566 Poll::Ready(Ok(_)) => {
571 let data_channel = self.data_channel.clone();
572 let bytes = Bytes::copy_from_slice(buf);
573 self.write_fut =
574 Some(Box::pin(async move { data_channel.write(&bytes).await }));
575 Poll::Ready(Ok(buf.len()))
576 }
577 }
578 } else {
579 let data_channel = self.data_channel.clone();
580 let bytes = Bytes::copy_from_slice(buf);
581 let fut = self
582 .write_fut
583 .insert(Box::pin(async move { data_channel.write(&bytes).await }));
584
585 match fut.as_mut().poll(cx) {
586 Poll::Pending => Poll::Ready(Ok(buf.len())),
594 Poll::Ready(Err(e)) => {
595 self.write_fut = None;
596 Poll::Ready(Err(e.into()))
597 }
598 Poll::Ready(Ok(n)) => {
599 self.write_fut = None;
600 Poll::Ready(Ok(n))
601 }
602 }
603 }
604 }
605
606 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
607 match self.write_fut.as_mut() {
608 Some(fut) => match fut.as_mut().poll(cx) {
609 Poll::Pending => Poll::Pending,
610 Poll::Ready(Err(e)) => {
611 self.write_fut = None;
612 Poll::Ready(Err(e.into()))
613 }
614 Poll::Ready(Ok(_)) => {
615 self.write_fut = None;
616 Poll::Ready(Ok(()))
617 }
618 },
619 None => Poll::Ready(Ok(())),
620 }
621 }
622
623 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
624 match self.as_mut().poll_flush(cx) {
625 Poll::Pending => return Poll::Pending,
626 Poll::Ready(_) => {}
627 }
628
629 let fut = match self.shutdown_fut.as_mut() {
630 Some(fut) => fut,
631 None => {
632 let data_channel = self.data_channel.clone();
633 self.shutdown_fut.get_or_insert(Box::pin(async move {
634 data_channel
635 .stream
636 .shutdown(Shutdown::Write)
637 .await
638 .map_err(Error::Sctp)
639 }))
640 }
641 };
642
643 match fut.as_mut().poll(cx) {
644 Poll::Pending => Poll::Pending,
645 Poll::Ready(Err(e)) => {
646 self.shutdown_fut = None;
647 Poll::Ready(Err(e.into()))
648 }
649 Poll::Ready(Ok(_)) => {
650 self.shutdown_fut = None;
651 Poll::Ready(Ok(()))
652 }
653 }
654 }
655}
656
657impl Clone for PollDataChannel {
658 fn clone(&self) -> PollDataChannel {
659 PollDataChannel::new(self.clone_inner())
660 }
661}
662
663impl fmt::Debug for PollDataChannel {
664 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665 f.debug_struct("PollDataChannel")
666 .field("data_channel", &self.data_channel)
667 .field("read_buf_cap", &self.read_buf_cap)
668 .finish()
669 }
670}
671
672impl AsRef<DataChannel> for PollDataChannel {
673 fn as_ref(&self) -> &DataChannel {
674 &self.data_channel
675 }
676}