sctp_async/stream/
mod.rs

1#[cfg(test)]
2mod stream_test;
3
4use crate::association::AssociationState;
5use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
6use crate::error::{Error, Result};
7use crate::queue::reassembly_queue::ReassemblyQueue;
8
9use crate::queue::pending_queue::PendingQueue;
10
11use bytes::Bytes;
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering};
16use std::sync::Arc;
17use tokio::sync::{mpsc, Mutex, Notify};
18
/// Delivery guarantee selected for the messages of a stream.
///
/// The discriminants are explicit because the value travels as a raw `u8`
/// (it is cast with `as u8` when stored and rebuilt through `From<u8>`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(C)]
pub enum ReliabilityType {
    /// Reliable transmission.
    Reliable = 0,
    /// Partial reliability bounded by a retransmission count.
    Rexmit = 1,
    /// Partial reliability bounded by a retransmission duration.
    Timed = 2,
}
29
30impl Default for ReliabilityType {
31    fn default() -> Self {
32        ReliabilityType::Reliable
33    }
34}
35
36impl fmt::Display for ReliabilityType {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        let s = match *self {
39            ReliabilityType::Reliable => "Reliable",
40            ReliabilityType::Rexmit => "Rexmit",
41            ReliabilityType::Timed => "Timed",
42        };
43        write!(f, "{}", s)
44    }
45}
46
47impl From<u8> for ReliabilityType {
48    fn from(v: u8) -> ReliabilityType {
49        match v {
50            1 => ReliabilityType::Rexmit,
51            2 => ReliabilityType::Timed,
52            _ => ReliabilityType::Reliable,
53        }
54    }
55}
56
/// Callback registered through `Stream::on_buffered_amount_low`.
///
/// It is a boxed `FnMut` taking no arguments and returning a boxed, pinned
/// future with no output; the stream awaits that future when the buffered
/// amount crosses from above the low threshold to at or below it.
pub type OnBufferedAmountLowFn =
    Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
59
60// TODO: benchmark performance between multiple Atomic+Mutex vs one Mutex<StreamInternal>
61
/// Stream represents an SCTP stream
///
/// Mutable per-stream state is kept in atomics and `tokio` mutexes so every
/// method can take `&self`.
#[derive(Default)]
pub struct Stream {
    /// Largest number of user-data bytes `packetize` puts into one chunk.
    pub(crate) max_payload_size: u32,
    /// Largest message, in bytes, that `write_sctp` accepts.
    pub(crate) max_message_size: Arc<AtomicU32>, // clone from association
    /// Association state as a raw `u8`; converted to `AssociationState` when read.
    pub(crate) state: Arc<AtomicU8>,             // clone from association
    /// Channel signalled (best effort, via `try_send`) after chunks were queued.
    pub(crate) awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
    /// Queue that receives the outgoing chunks produced by this stream.
    pub(crate) pending_queue: Arc<PendingQueue>,

    /// Identifier stamped on every chunk this stream produces.
    pub(crate) stream_identifier: u16,
    /// Raw `u32` form of the payload protocol identifier used by `write`.
    pub(crate) default_payload_type: AtomicU32, //PayloadProtocolIdentifier,
    /// Inbound chunks: filled by `handle_data`, drained by `read_sctp`.
    pub(crate) reassembly_queue: Mutex<ReassemblyQueue>,
    /// Stream sequence number assigned to outgoing chunks by `packetize`.
    pub(crate) sequence_number: AtomicU16,
    /// Wakes readers parked in `read_sctp`.
    pub(crate) read_notifier: Notify,
    /// Set by `close`; makes `read_sctp` return `ErrStreamClosed`.
    pub(crate) closed: AtomicBool,
    /// Whether outgoing chunks are flagged unordered (DCEP messages excepted).
    pub(crate) unordered: AtomicBool,
    /// Raw `u8` form of the `ReliabilityType` set by `set_reliability_params`.
    pub(crate) reliability_type: AtomicU8, //ReliabilityType,
    /// Value paired with `reliability_type`; its consumer is not visible in this file.
    pub(crate) reliability_value: AtomicU32,
    /// Bytes handed to `packetize` and not yet reported through `on_buffer_released`.
    pub(crate) buffered_amount: AtomicUsize,
    /// Threshold at or below which the buffered amount counts as "low".
    pub(crate) buffered_amount_low: AtomicUsize,
    /// Optional callback run when the buffered amount drops to the low threshold.
    pub(crate) on_buffered_amount_low: Mutex<Option<OnBufferedAmountLowFn>>,
    /// Label used as the `[name]` prefix of this stream's log lines.
    pub(crate) name: String,
}
85
86impl fmt::Debug for Stream {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("Stream")
89            .field("max_payload_size", &self.max_payload_size)
90            .field("max_message_size", &self.max_message_size)
91            .field("state", &self.state)
92            .field("awake_write_loop_ch", &self.awake_write_loop_ch)
93            .field("stream_identifier", &self.stream_identifier)
94            .field("default_payload_type", &self.default_payload_type)
95            .field("reassembly_queue", &self.reassembly_queue)
96            .field("sequence_number", &self.sequence_number)
97            .field("closed", &self.closed)
98            .field("unordered", &self.unordered)
99            .field("reliability_type", &self.reliability_type)
100            .field("reliability_value", &self.reliability_value)
101            .field("buffered_amount", &self.buffered_amount)
102            .field("buffered_amount_low", &self.buffered_amount_low)
103            .field("name", &self.name)
104            .finish()
105    }
106}
107
108impl Stream {
109    pub(crate) fn new(
110        name: String,
111        stream_identifier: u16,
112        max_payload_size: u32,
113        max_message_size: Arc<AtomicU32>,
114        state: Arc<AtomicU8>,
115        awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
116        pending_queue: Arc<PendingQueue>,
117    ) -> Self {
118        Stream {
119            max_payload_size,
120            max_message_size,
121            state,
122            awake_write_loop_ch,
123            pending_queue,
124
125            stream_identifier,
126            default_payload_type: AtomicU32::new(0), //PayloadProtocolIdentifier::Unknown,
127            reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)),
128            sequence_number: AtomicU16::new(0),
129            read_notifier: Notify::new(),
130            closed: AtomicBool::new(false),
131            unordered: AtomicBool::new(false),
132            reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable,
133            reliability_value: AtomicU32::new(0),
134            buffered_amount: AtomicUsize::new(0),
135            buffered_amount_low: AtomicUsize::new(0),
136            on_buffered_amount_low: Mutex::new(None),
137            name,
138        }
139    }
140
    /// Returns the identifier this stream was created with.
    ///
    /// The same value is stamped on every chunk the stream produces.
    pub fn stream_identifier(&self) -> u16 {
        self.stream_identifier
    }
145
146    /// set_default_payload_type sets the default payload type used by write.
147    pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) {
148        self.default_payload_type
149            .store(default_payload_type as u32, Ordering::SeqCst);
150    }
151
152    /// set_reliability_params sets reliability parameters for this stream.
153    pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) {
154        log::debug!(
155            "[{}] reliability params: ordered={} type={} value={}",
156            self.name,
157            !unordered,
158            rel_type,
159            rel_val
160        );
161        self.unordered.store(unordered, Ordering::SeqCst);
162        self.reliability_type
163            .store(rel_type as u8, Ordering::SeqCst);
164        self.reliability_value.store(rel_val, Ordering::SeqCst);
165    }
166
167    /// read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
168    /// Returns EOF when the stream is reset or an error if the stream is closed
169    /// otherwise.
170    pub async fn read(&self, p: &mut [u8]) -> Result<usize> {
171        let (n, _) = self.read_sctp(p).await?;
172        Ok(n)
173    }
174
175    /// read_sctp reads a packet of len(p) bytes and returns the associated Payload
176    /// Protocol Identifier.
177    /// Returns EOF when the stream is reset or an error if the stream is closed
178    /// otherwise.
179    pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
180        while !self.closed.load(Ordering::SeqCst) {
181            let result = {
182                let mut reassembly_queue = self.reassembly_queue.lock().await;
183                reassembly_queue.read(p)
184            };
185
186            if result.is_ok() {
187                return result;
188            } else if let Err(err) = result {
189                if Error::ErrShortBuffer == err {
190                    return Err(err);
191                }
192            }
193
194            self.read_notifier.notified().await;
195        }
196
197        Err(Error::ErrStreamClosed)
198    }
199
200    pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) {
201        let readable = {
202            let mut reassembly_queue = self.reassembly_queue.lock().await;
203            if reassembly_queue.push(pd) {
204                let readable = reassembly_queue.is_readable();
205                log::debug!("[{}] reassemblyQueue readable={}", self.name, readable);
206                readable
207            } else {
208                false
209            }
210        };
211
212        if readable {
213            log::debug!("[{}] readNotifier.signal()", self.name);
214            self.read_notifier.notify_one();
215            log::debug!("[{}] readNotifier.signal() done", self.name);
216        }
217    }
218
219    pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) {
220        if self.unordered.load(Ordering::SeqCst) {
221            return; // unordered chunks are handled by handleForwardUnordered method
222        }
223
224        // Remove all chunks older than or equal to the new TSN from
225        // the reassembly_queue.
226        let readable = {
227            let mut reassembly_queue = self.reassembly_queue.lock().await;
228            reassembly_queue.forward_tsn_for_ordered(ssn);
229            reassembly_queue.is_readable()
230        };
231
232        // Notify the reader asynchronously if there's a data chunk to read.
233        if readable {
234            self.read_notifier.notify_one();
235        }
236    }
237
238    pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) {
239        if !self.unordered.load(Ordering::SeqCst) {
240            return; // ordered chunks are handled by handleForwardTSNOrdered method
241        }
242
243        // Remove all chunks older than or equal to the new TSN from
244        // the reassembly_queue.
245        let readable = {
246            let mut reassembly_queue = self.reassembly_queue.lock().await;
247            reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn);
248            reassembly_queue.is_readable()
249        };
250
251        // Notify the reader asynchronously if there's a data chunk to read.
252        if readable {
253            self.read_notifier.notify_one();
254        }
255    }
256
257    /// write writes len(p) bytes from p with the default Payload Protocol Identifier
258    pub async fn write(&self, p: &Bytes) -> Result<usize> {
259        self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into())
260            .await
261    }
262
263    /// write_sctp writes len(p) bytes from p to the DTLS connection
264    pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
265        if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize {
266            return Err(Error::ErrOutboundPacketTooLarge);
267        }
268
269        let state: AssociationState = self.state.load(Ordering::SeqCst).into();
270        match state {
271            AssociationState::ShutdownSent
272            | AssociationState::ShutdownAckSent
273            | AssociationState::ShutdownPending
274            | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
275            _ => {}
276        };
277
278        let chunks = self.packetize(p, ppi);
279        self.send_payload_data(chunks).await?;
280
281        Ok(p.len())
282    }
283
284    fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> {
285        let mut i = 0;
286        let mut remaining = raw.len();
287
288        // From draft-ietf-rtcweb-data-protocol-09, section 6:
289        //   All Data Channel Establishment Protocol messages MUST be sent using
290        //   ordered delivery and reliable transmission.
291        let unordered =
292            ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst);
293
294        let mut chunks = vec![];
295
296        let head_abandoned = Arc::new(AtomicBool::new(false));
297        let head_all_inflight = Arc::new(AtomicBool::new(false));
298        while remaining != 0 {
299            let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); //self.association.max_payload_size
300
301            // Copy the userdata since we'll have to store it until acked
302            // and the caller may re-use the buffer in the mean time
303            let user_data = raw.slice(i..i + fragment_size);
304
305            let chunk = ChunkPayloadData {
306                stream_identifier: self.stream_identifier,
307                user_data,
308                unordered,
309                beginning_fragment: i == 0,
310                ending_fragment: remaining - fragment_size == 0,
311                immediate_sack: false,
312                payload_type: ppi,
313                stream_sequence_number: self.sequence_number.load(Ordering::SeqCst),
314                abandoned: head_abandoned.clone(), // all fragmented chunks use the same abandoned
315                all_inflight: head_all_inflight.clone(), // all fragmented chunks use the same all_inflight
316                ..Default::default()
317            };
318
319            chunks.push(chunk);
320
321            remaining -= fragment_size;
322            i += fragment_size;
323        }
324
325        // RFC 4960 Sec 6.6
326        // Note: When transmitting ordered and unordered data, an endpoint does
327        // not increment its Stream Sequence Number when transmitting a DATA
328        // chunk with U flag set to 1.
329        if !unordered {
330            self.sequence_number.fetch_add(1, Ordering::SeqCst);
331        }
332
333        let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst);
334        log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len());
335
336        chunks
337    }
338
339    /// Close closes the write-direction of the stream.
340    /// Future calls to write are not permitted after calling Close.
341    pub async fn close(&self) -> Result<()> {
342        if !self.closed.load(Ordering::SeqCst) {
343            // Reset the outgoing stream
344            // https://tools.ietf.org/html/rfc6525
345            self.send_reset_request(self.stream_identifier).await?;
346        }
347        self.closed.store(true, Ordering::SeqCst);
348        self.read_notifier.notify_waiters(); // broadcast regardless
349
350        Ok(())
351    }
352
    /// Returns the number of bytes currently accounted as queued for sending on this stream.
    ///
    /// The counter grows when a message is packetized and shrinks when
    /// `on_buffer_released` is called.
    pub fn buffered_amount(&self) -> usize {
        self.buffered_amount.load(Ordering::SeqCst)
    }
357
    /// Returns the number of buffered outgoing bytes at or below which the
    /// buffered amount is considered "low". A stream built by `new` starts at 0.
    pub fn buffered_amount_low_threshold(&self) -> usize {
        self.buffered_amount_low.load(Ordering::SeqCst)
    }
363
    /// Replaces the "low" threshold with `th` bytes.
    /// See `buffered_amount_low_threshold()`.
    pub fn set_buffered_amount_low_threshold(&self, th: usize) {
        self.buffered_amount_low.store(th, Ordering::SeqCst);
    }
369
370    /// on_buffered_amount_low sets the callback handler which would be called when the number of
371    /// bytes of outgoing data buffered is lower than the threshold.
372    pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
373        let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await;
374        *on_buffered_amount_low = Some(f);
375    }
376
377    /// This method is called by association's read_loop (go-)routine to notify this stream
378    /// of the specified amount of outgoing data has been delivered to the peer.
379    pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) {
380        if n_bytes_released <= 0 {
381            return;
382        }
383
384        let from_amount = self.buffered_amount.load(Ordering::SeqCst);
385        let new_amount = if from_amount < n_bytes_released as usize {
386            self.buffered_amount.store(0, Ordering::SeqCst);
387            log::error!(
388                "[{}] released buffer size {} should be <= {}",
389                self.name,
390                n_bytes_released,
391                0,
392            );
393            0
394        } else {
395            self.buffered_amount
396                .fetch_sub(n_bytes_released as usize, Ordering::SeqCst);
397
398            from_amount - n_bytes_released as usize
399        };
400
401        let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst);
402
403        log::trace!(
404            "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}",
405            self.name,
406            new_amount,
407            from_amount,
408            buffered_amount_low,
409        );
410
411        if from_amount > buffered_amount_low && new_amount <= buffered_amount_low {
412            let mut handler = self.on_buffered_amount_low.lock().await;
413            if let Some(f) = &mut *handler {
414                f().await;
415            }
416        }
417    }
418
419    pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
420        // No lock is required as it reads the size with atomic load function.
421        let reassembly_queue = self.reassembly_queue.lock().await;
422        reassembly_queue.get_num_bytes()
423    }
424
    /// get_state atomically returns the state of the Association.
    ///
    /// The state is stored as a raw `u8` shared with the association and is
    /// converted to `AssociationState` on every call.
    fn get_state(&self) -> AssociationState {
        self.state.load(Ordering::SeqCst).into()
    }
429
430    fn awake_write_loop(&self) {
431        //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name);
432        if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch {
433            let _ = awake_write_loop_ch.try_send(());
434        }
435    }
436
437    async fn send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()> {
438        let state = self.get_state();
439        if state != AssociationState::Established {
440            return Err(Error::ErrPayloadDataStateNotExist);
441        }
442
443        // Push the chunks into the pending queue first.
444        for c in chunks {
445            self.pending_queue.push(c).await;
446        }
447
448        self.awake_write_loop();
449        Ok(())
450    }
451
452    async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> {
453        let state = self.get_state();
454        if state != AssociationState::Established {
455            return Err(Error::ErrResetPacketInStateNotExist);
456        }
457
458        // Create DATA chunk which only contains valid stream identifier with
459        // nil userData and use it as a EOS from the stream.
460        let c = ChunkPayloadData {
461            stream_identifier,
462            beginning_fragment: true,
463            ending_fragment: true,
464            user_data: Bytes::new(),
465            ..Default::default()
466        };
467
468        self.pending_queue.push(c).await;
469
470        self.awake_write_loop();
471        Ok(())
472    }
473}