1#[cfg(test)]
2mod stream_test;
3
4use crate::association::AssociationState;
5use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
6use crate::error::{Error, Result};
7use crate::queue::reassembly_queue::ReassemblyQueue;
8
9use crate::queue::pending_queue::PendingQueue;
10
11use bytes::Bytes;
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering};
16use std::sync::Arc;
17use tokio::sync::{mpsc, Mutex, Notify};
18
/// Delivery-reliability mode of a stream (stored on `Stream` as its `u8` discriminant).
#[derive(Debug, Copy, Clone, PartialEq)]
#[repr(C)]
pub enum ReliabilityType {
    /// Reliable delivery; this is the default mode.
    Reliable = 0,
    /// Partial reliability, presumably bounded by a retransmission count.
    Rexmit = 1,
    /// Partial reliability, presumably bounded by a time limit.
    Timed = 2,
}

impl Default for ReliabilityType {
    /// Streams start out `Reliable`.
    fn default() -> Self {
        Self::Reliable
    }
}

impl fmt::Display for ReliabilityType {
    /// Writes the bare variant name (`Reliable`, `Rexmit` or `Timed`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Reliable => "Reliable",
            Self::Rexmit => "Rexmit",
            Self::Timed => "Timed",
        };
        f.write_str(label)
    }
}

impl From<u8> for ReliabilityType {
    /// Decodes a raw discriminant; any unrecognised value falls back to `Reliable`.
    fn from(v: u8) -> ReliabilityType {
        match v {
            2 => Self::Timed,
            1 => Self::Rexmit,
            _ => Self::Reliable,
        }
    }
}
56
/// Callback fired when a stream's buffered amount falls to or below its low threshold.
///
/// Each invocation returns a boxed `Send` future; the stream awaits that future before
/// continuing.
pub type OnBufferedAmountLowFn = Box<
    dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync,
>;
59
60#[derive(Default)]
64pub struct Stream {
65 pub(crate) max_payload_size: u32,
66 pub(crate) max_message_size: Arc<AtomicU32>, pub(crate) state: Arc<AtomicU8>, pub(crate) awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
69 pub(crate) pending_queue: Arc<PendingQueue>,
70
71 pub(crate) stream_identifier: u16,
72 pub(crate) default_payload_type: AtomicU32, pub(crate) reassembly_queue: Mutex<ReassemblyQueue>,
74 pub(crate) sequence_number: AtomicU16,
75 pub(crate) read_notifier: Notify,
76 pub(crate) closed: AtomicBool,
77 pub(crate) unordered: AtomicBool,
78 pub(crate) reliability_type: AtomicU8, pub(crate) reliability_value: AtomicU32,
80 pub(crate) buffered_amount: AtomicUsize,
81 pub(crate) buffered_amount_low: AtomicUsize,
82 pub(crate) on_buffered_amount_low: Mutex<Option<OnBufferedAmountLowFn>>,
83 pub(crate) name: String,
84}
85
86impl fmt::Debug for Stream {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("Stream")
89 .field("max_payload_size", &self.max_payload_size)
90 .field("max_message_size", &self.max_message_size)
91 .field("state", &self.state)
92 .field("awake_write_loop_ch", &self.awake_write_loop_ch)
93 .field("stream_identifier", &self.stream_identifier)
94 .field("default_payload_type", &self.default_payload_type)
95 .field("reassembly_queue", &self.reassembly_queue)
96 .field("sequence_number", &self.sequence_number)
97 .field("closed", &self.closed)
98 .field("unordered", &self.unordered)
99 .field("reliability_type", &self.reliability_type)
100 .field("reliability_value", &self.reliability_value)
101 .field("buffered_amount", &self.buffered_amount)
102 .field("buffered_amount_low", &self.buffered_amount_low)
103 .field("name", &self.name)
104 .finish()
105 }
106}
107
108impl Stream {
109 pub(crate) fn new(
110 name: String,
111 stream_identifier: u16,
112 max_payload_size: u32,
113 max_message_size: Arc<AtomicU32>,
114 state: Arc<AtomicU8>,
115 awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
116 pending_queue: Arc<PendingQueue>,
117 ) -> Self {
118 Stream {
119 max_payload_size,
120 max_message_size,
121 state,
122 awake_write_loop_ch,
123 pending_queue,
124
125 stream_identifier,
126 default_payload_type: AtomicU32::new(0), reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)),
128 sequence_number: AtomicU16::new(0),
129 read_notifier: Notify::new(),
130 closed: AtomicBool::new(false),
131 unordered: AtomicBool::new(false),
132 reliability_type: AtomicU8::new(0), reliability_value: AtomicU32::new(0),
134 buffered_amount: AtomicUsize::new(0),
135 buffered_amount_low: AtomicUsize::new(0),
136 on_buffered_amount_low: Mutex::new(None),
137 name,
138 }
139 }
140
141 pub fn stream_identifier(&self) -> u16 {
143 self.stream_identifier
144 }
145
146 pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) {
148 self.default_payload_type
149 .store(default_payload_type as u32, Ordering::SeqCst);
150 }
151
152 pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) {
154 log::debug!(
155 "[{}] reliability params: ordered={} type={} value={}",
156 self.name,
157 !unordered,
158 rel_type,
159 rel_val
160 );
161 self.unordered.store(unordered, Ordering::SeqCst);
162 self.reliability_type
163 .store(rel_type as u8, Ordering::SeqCst);
164 self.reliability_value.store(rel_val, Ordering::SeqCst);
165 }
166
167 pub async fn read(&self, p: &mut [u8]) -> Result<usize> {
171 let (n, _) = self.read_sctp(p).await?;
172 Ok(n)
173 }
174
175 pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
180 while !self.closed.load(Ordering::SeqCst) {
181 let result = {
182 let mut reassembly_queue = self.reassembly_queue.lock().await;
183 reassembly_queue.read(p)
184 };
185
186 if result.is_ok() {
187 return result;
188 } else if let Err(err) = result {
189 if Error::ErrShortBuffer == err {
190 return Err(err);
191 }
192 }
193
194 self.read_notifier.notified().await;
195 }
196
197 Err(Error::ErrStreamClosed)
198 }
199
200 pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) {
201 let readable = {
202 let mut reassembly_queue = self.reassembly_queue.lock().await;
203 if reassembly_queue.push(pd) {
204 let readable = reassembly_queue.is_readable();
205 log::debug!("[{}] reassemblyQueue readable={}", self.name, readable);
206 readable
207 } else {
208 false
209 }
210 };
211
212 if readable {
213 log::debug!("[{}] readNotifier.signal()", self.name);
214 self.read_notifier.notify_one();
215 log::debug!("[{}] readNotifier.signal() done", self.name);
216 }
217 }
218
219 pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) {
220 if self.unordered.load(Ordering::SeqCst) {
221 return; }
223
224 let readable = {
227 let mut reassembly_queue = self.reassembly_queue.lock().await;
228 reassembly_queue.forward_tsn_for_ordered(ssn);
229 reassembly_queue.is_readable()
230 };
231
232 if readable {
234 self.read_notifier.notify_one();
235 }
236 }
237
238 pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) {
239 if !self.unordered.load(Ordering::SeqCst) {
240 return; }
242
243 let readable = {
246 let mut reassembly_queue = self.reassembly_queue.lock().await;
247 reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn);
248 reassembly_queue.is_readable()
249 };
250
251 if readable {
253 self.read_notifier.notify_one();
254 }
255 }
256
257 pub async fn write(&self, p: &Bytes) -> Result<usize> {
259 self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into())
260 .await
261 }
262
263 pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
265 if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize {
266 return Err(Error::ErrOutboundPacketTooLarge);
267 }
268
269 let state: AssociationState = self.state.load(Ordering::SeqCst).into();
270 match state {
271 AssociationState::ShutdownSent
272 | AssociationState::ShutdownAckSent
273 | AssociationState::ShutdownPending
274 | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
275 _ => {}
276 };
277
278 let chunks = self.packetize(p, ppi);
279 self.send_payload_data(chunks).await?;
280
281 Ok(p.len())
282 }
283
284 fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> {
285 let mut i = 0;
286 let mut remaining = raw.len();
287
288 let unordered =
292 ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst);
293
294 let mut chunks = vec![];
295
296 let head_abandoned = Arc::new(AtomicBool::new(false));
297 let head_all_inflight = Arc::new(AtomicBool::new(false));
298 while remaining != 0 {
299 let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); let user_data = raw.slice(i..i + fragment_size);
304
305 let chunk = ChunkPayloadData {
306 stream_identifier: self.stream_identifier,
307 user_data,
308 unordered,
309 beginning_fragment: i == 0,
310 ending_fragment: remaining - fragment_size == 0,
311 immediate_sack: false,
312 payload_type: ppi,
313 stream_sequence_number: self.sequence_number.load(Ordering::SeqCst),
314 abandoned: head_abandoned.clone(), all_inflight: head_all_inflight.clone(), ..Default::default()
317 };
318
319 chunks.push(chunk);
320
321 remaining -= fragment_size;
322 i += fragment_size;
323 }
324
325 if !unordered {
330 self.sequence_number.fetch_add(1, Ordering::SeqCst);
331 }
332
333 let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst);
334 log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len());
335
336 chunks
337 }
338
339 pub async fn close(&self) -> Result<()> {
342 if !self.closed.load(Ordering::SeqCst) {
343 self.send_reset_request(self.stream_identifier).await?;
346 }
347 self.closed.store(true, Ordering::SeqCst);
348 self.read_notifier.notify_waiters(); Ok(())
351 }
352
353 pub fn buffered_amount(&self) -> usize {
355 self.buffered_amount.load(Ordering::SeqCst)
356 }
357
358 pub fn buffered_amount_low_threshold(&self) -> usize {
361 self.buffered_amount_low.load(Ordering::SeqCst)
362 }
363
364 pub fn set_buffered_amount_low_threshold(&self, th: usize) {
367 self.buffered_amount_low.store(th, Ordering::SeqCst);
368 }
369
370 pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
373 let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await;
374 *on_buffered_amount_low = Some(f);
375 }
376
377 pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) {
380 if n_bytes_released <= 0 {
381 return;
382 }
383
384 let from_amount = self.buffered_amount.load(Ordering::SeqCst);
385 let new_amount = if from_amount < n_bytes_released as usize {
386 self.buffered_amount.store(0, Ordering::SeqCst);
387 log::error!(
388 "[{}] released buffer size {} should be <= {}",
389 self.name,
390 n_bytes_released,
391 0,
392 );
393 0
394 } else {
395 self.buffered_amount
396 .fetch_sub(n_bytes_released as usize, Ordering::SeqCst);
397
398 from_amount - n_bytes_released as usize
399 };
400
401 let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst);
402
403 log::trace!(
404 "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}",
405 self.name,
406 new_amount,
407 from_amount,
408 buffered_amount_low,
409 );
410
411 if from_amount > buffered_amount_low && new_amount <= buffered_amount_low {
412 let mut handler = self.on_buffered_amount_low.lock().await;
413 if let Some(f) = &mut *handler {
414 f().await;
415 }
416 }
417 }
418
419 pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
420 let reassembly_queue = self.reassembly_queue.lock().await;
422 reassembly_queue.get_num_bytes()
423 }
424
425 fn get_state(&self) -> AssociationState {
427 self.state.load(Ordering::SeqCst).into()
428 }
429
430 fn awake_write_loop(&self) {
431 if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch {
433 let _ = awake_write_loop_ch.try_send(());
434 }
435 }
436
437 async fn send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()> {
438 let state = self.get_state();
439 if state != AssociationState::Established {
440 return Err(Error::ErrPayloadDataStateNotExist);
441 }
442
443 for c in chunks {
445 self.pending_queue.push(c).await;
446 }
447
448 self.awake_write_loop();
449 Ok(())
450 }
451
452 async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> {
453 let state = self.get_state();
454 if state != AssociationState::Established {
455 return Err(Error::ErrResetPacketInStateNotExist);
456 }
457
458 let c = ChunkPayloadData {
461 stream_identifier,
462 beginning_fragment: true,
463 ending_fragment: true,
464 user_data: Bytes::new(),
465 ..Default::default()
466 };
467
468 self.pending_queue.push(c).await;
469
470 self.awake_write_loop();
471 Ok(())
472 }
473}