1extern crate alloc;
24use alloc::collections::BTreeMap;
25use alloc::vec::Vec;
26use core::time::Duration;
27
28use crate::error::XrceError;
29use crate::header::StreamId;
30use crate::serial_number::SerialNumber16;
31use crate::submessages::{AckNackPayload, HeartbeatPayload};
32
/// Default interval between two heartbeats of a reliable stream.
///
/// Used as [`ReliableConfig::heartbeat_period`] by `ReliableConfig::default()`.
pub const DEFAULT_HEARTBEAT_PERIOD: Duration = Duration::from_millis(500);

/// Default number of payloads a sender may keep in flight (unacknowledged).
///
/// Used as [`ReliableConfig::sender_window`] by `ReliableConfig::default()`.
pub const SENDER_WINDOW_CAP: usize = 16;

/// Default number of received payloads a receiver buffers while waiting for
/// the next in-order sequence number.
///
/// Used as [`ReliableConfig::receiver_buffer`] by `ReliableConfig::default()`.
pub const RECEIVER_BUFFER_CAP: usize = 64;

/// Largest payload, in bytes, that `ReliableStreamState::submit` accepts.
pub const RELIABLE_MAX_PAYLOAD: usize = 65_535;
48
/// Tunable limits and timing of a [`ReliableStreamState`].
#[derive(Debug, Clone, Copy)]
pub struct ReliableConfig {
    /// Minimum time between two heartbeats while payloads are in flight.
    pub heartbeat_period: Duration,
    /// Maximum number of submitted payloads that may await acknowledgement.
    pub sender_window: usize,
    /// Maximum number of received payloads buffered ahead of delivery.
    pub receiver_buffer: usize,
}
59
60impl Default for ReliableConfig {
61 fn default() -> Self {
62 Self {
63 heartbeat_period: DEFAULT_HEARTBEAT_PERIOD,
64 sender_window: SENDER_WINDOW_CAP,
65 receiver_buffer: RECEIVER_BUFFER_CAP,
66 }
67 }
68}
69
/// Sender and receiver bookkeeping for one reliable stream.
///
/// The same value tracks both directions: the sender half (`next_seq`,
/// `in_flight`, `last_heartbeat`) and the receiver half (`expected_seq`,
/// `received`).
#[derive(Debug, Clone)]
pub struct ReliableStreamState {
    /// Identifier of the stream this state belongs to.
    stream_id: StreamId,
    /// Limits and timing applied to this stream.
    config: ReliableConfig,

    /// Sequence number that the next submitted payload will receive.
    next_seq: SerialNumber16,
    /// Submitted payloads still awaiting acknowledgement, keyed by raw
    /// sequence number.
    in_flight: BTreeMap<u16, Vec<u8>>,
    /// Timestamp of the most recently emitted heartbeat, if any.
    last_heartbeat: Option<Duration>,

    /// Next sequence number to be delivered in order.
    expected_seq: SerialNumber16,
    /// Received payloads not yet delivered, keyed by raw sequence number.
    received: BTreeMap<u16, Vec<u8>>,
}
91
92impl ReliableStreamState {
93 #[must_use]
98 pub fn new(stream_id: StreamId, config: ReliableConfig) -> Self {
99 assert!(
100 stream_id.is_reliable(),
101 "ReliableStreamState requires reliable stream id (>=128)"
102 );
103 Self {
104 stream_id,
105 config,
106 next_seq: SerialNumber16::new(0),
107 in_flight: BTreeMap::new(),
108 last_heartbeat: None,
109 expected_seq: SerialNumber16::new(0),
110 received: BTreeMap::new(),
111 }
112 }
113
114 #[must_use]
116 pub fn stream_id(&self) -> StreamId {
117 self.stream_id
118 }
119
120 #[must_use]
122 pub fn in_flight_count(&self) -> usize {
123 self.in_flight.len()
124 }
125
126 #[must_use]
128 pub fn out_of_order_count(&self) -> usize {
129 self.received.len()
130 }
131
132 #[must_use]
134 pub fn expected(&self) -> SerialNumber16 {
135 self.expected_seq
136 }
137
138 pub fn submit(&mut self, payload: Vec<u8>) -> Result<SerialNumber16, XrceError> {
149 if payload.len() > RELIABLE_MAX_PAYLOAD {
150 return Err(XrceError::PayloadTooLarge {
151 limit: RELIABLE_MAX_PAYLOAD,
152 actual: payload.len(),
153 });
154 }
155 if self.in_flight.len() >= self.config.sender_window {
156 return Err(XrceError::ValueOutOfRange {
157 message: "reliable sender window full",
158 });
159 }
160 let seq = self.next_seq;
161 self.in_flight.insert(seq.raw(), payload);
162 self.next_seq = self.next_seq.next();
163 Ok(seq)
164 }
165
166 #[must_use]
168 pub fn get_in_flight(&self, seq: SerialNumber16) -> Option<&[u8]> {
169 self.in_flight.get(&seq.raw()).map(Vec::as_slice)
170 }
171
172 pub fn pending_heartbeat(&mut self, now: Duration) -> Option<HeartbeatPayload> {
175 if self.in_flight.is_empty() {
176 return None;
177 }
178 let due = match self.last_heartbeat {
179 None => true,
180 Some(t) => now.saturating_sub(t) >= self.config.heartbeat_period,
181 };
182 if !due {
183 return None;
184 }
185 self.last_heartbeat = Some(now);
186 let first = *self.in_flight.keys().next()?;
187 let last = *self.in_flight.keys().next_back()?;
188 Some(HeartbeatPayload {
189 first_unacked_seq_nr: first as i16,
192 last_unacked_seq_nr: last as i16,
193 stream_id: self.stream_id.0,
194 })
195 }
196
197 pub fn recv_acknack(&mut self, payload: AckNackPayload) {
206 let base = payload.first_unacked_seq_num as u16;
207 let bitmap = u16::from_le_bytes(payload.nack_bitmap);
208
209 let to_remove: Vec<u16> = self
211 .in_flight
212 .keys()
213 .copied()
214 .filter(|&k| {
215 let diff = base.wrapping_sub(k);
216 diff > 0 && diff < SerialNumber16::HALF_WINDOW
218 })
219 .collect();
220 for k in to_remove {
221 self.in_flight.remove(&k);
222 }
223
224 for i in 0u16..16 {
227 let seq = base.wrapping_add(i);
228 let bit = (bitmap >> i) & 1;
229 if bit == 0 {
230 self.in_flight.remove(&seq);
232 }
233 }
234 }
235
236 pub fn recv_data(&mut self, seq: SerialNumber16, payload: Vec<u8>) -> Result<(), XrceError> {
246 if seq.wrapping_lt(self.expected_seq) {
248 return Ok(()); }
250 if self.received.contains_key(&seq.raw()) {
251 return Ok(()); }
253 if self.received.len() >= self.config.receiver_buffer {
254 return Err(XrceError::ValueOutOfRange {
255 message: "reliable receiver buffer full",
256 });
257 }
258 self.received.insert(seq.raw(), payload);
259 Ok(())
260 }
261
262 pub fn drain_in_order(&mut self) -> Vec<(SerialNumber16, Vec<u8>)> {
265 let mut out = Vec::new();
266 loop {
267 let key = self.expected_seq.raw();
268 if let Some(payload) = self.received.remove(&key) {
269 out.push((self.expected_seq, payload));
270 self.expected_seq = self.expected_seq.next();
271 } else {
272 break;
273 }
274 }
275 out
276 }
277
278 #[must_use]
283 pub fn pending_acknack(&self, hint_last_seen: Option<SerialNumber16>) -> AckNackPayload {
284 let base = self.expected_seq;
285 let mut bitmap: u16 = 0;
286 for i in 0u16..16 {
291 let seq = base.next().0.wrapping_sub(1).wrapping_add(i);
292 let s = SerialNumber16::new(seq);
293 if let Some(h) = hint_last_seen {
295 if s.wrapping_gt(h) {
296 continue;
297 }
298 }
299 if !self.received.contains_key(&seq) {
301 bitmap |= 1u16 << i;
302 }
303 }
304 AckNackPayload {
305 first_unacked_seq_num: base.raw() as i16,
306 nack_bitmap: bitmap.to_le_bytes(),
307 stream_id: self.stream_id.0,
308 }
309 }
310
311 pub fn reset(&mut self) {
313 self.next_seq = SerialNumber16::new(0);
314 self.in_flight.clear();
315 self.last_heartbeat = None;
316 self.expected_seq = SerialNumber16::new(0);
317 self.received.clear();
318 }
319}
320
#[cfg(test)]
mod tests {
    // Unit tests for `ReliableStreamState`; unwrap/expect are acceptable here
    // because a failure is exactly what a test should report.
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    // Fresh state on the built-in reliable stream with the default config.
    fn rs() -> ReliableStreamState {
        ReliableStreamState::new(StreamId::BUILTIN_RELIABLE, ReliableConfig::default())
    }

    // Consecutive submits receive sequence numbers 0, 1, ... and stay in flight.
    #[test]
    fn submit_assigns_monotonic_seqnrs() {
        let mut s = rs();
        let s0 = s.submit(alloc::vec![1, 2]).unwrap();
        let s1 = s.submit(alloc::vec![3, 4]).unwrap();
        assert_eq!(s0.raw(), 0);
        assert_eq!(s1.raw(), 1);
        assert_eq!(s.in_flight_count(), 2);
    }

    // A payload one byte over the limit is rejected with `PayloadTooLarge`.
    #[test]
    fn submit_rejects_payload_too_large() {
        let mut s = rs();
        let huge = alloc::vec![0u8; RELIABLE_MAX_PAYLOAD + 1];
        assert!(matches!(
            s.submit(huge),
            Err(XrceError::PayloadTooLarge { .. })
        ));
    }

    // Once the default sender window is full, a further submit fails.
    #[test]
    fn submit_rejects_when_window_full() {
        let mut s = rs();
        for _ in 0..SENDER_WINDOW_CAP {
            s.submit(alloc::vec![0]).unwrap();
        }
        assert!(s.submit(alloc::vec![0]).is_err());
    }

    // The first heartbeat is due immediately and covers the single in-flight
    // sequence number.
    #[test]
    fn pending_heartbeat_fires_first_time() {
        let mut s = rs();
        s.submit(alloc::vec![1]).unwrap();
        let hb = s.pending_heartbeat(Duration::from_secs(0));
        assert!(hb.is_some());
        let h = hb.unwrap();
        assert_eq!(h.first_unacked_seq_nr, 0);
        assert_eq!(h.last_unacked_seq_nr, 0);
        assert_eq!(h.stream_id, StreamId::BUILTIN_RELIABLE.0);
    }

    // With the default 500 ms period: heartbeat at 0 ms, none at 100 ms,
    // the next one at 600 ms.
    #[test]
    fn pending_heartbeat_silenced_until_period_elapsed() {
        let mut s = rs();
        s.submit(alloc::vec![1]).unwrap();
        assert!(s.pending_heartbeat(Duration::from_millis(0)).is_some());
        assert!(s.pending_heartbeat(Duration::from_millis(100)).is_none());
        assert!(s.pending_heartbeat(Duration::from_millis(600)).is_some());
    }

    // No heartbeat is produced when nothing is in flight.
    #[test]
    fn pending_heartbeat_none_when_window_empty() {
        let mut s = rs();
        assert!(s.pending_heartbeat(Duration::from_secs(0)).is_none());
    }

    // ACKNACK with first_unacked = 2 and bit 0 set: sequence numbers 0 and 1
    // are released, sequence number 2 is NACKed and stays in flight.
    #[test]
    fn recv_acknack_clears_acked_seqnrs() {
        let mut s = rs();
        s.submit(alloc::vec![0xA0]).unwrap();
        s.submit(alloc::vec![0xA1]).unwrap();
        s.submit(alloc::vec![0xA2]).unwrap();
        assert_eq!(s.in_flight_count(), 3);
        let ack = AckNackPayload {
            first_unacked_seq_num: 2,
            nack_bitmap: [0x01, 0x00],
            stream_id: StreamId::BUILTIN_RELIABLE.0,
        };
        s.recv_acknack(ack);
        assert_eq!(s.in_flight_count(), 1);
        assert!(s.get_in_flight(SerialNumber16::new(2)).is_some());
    }

    // ACKNACK with first_unacked = 5 and an empty bitmap releases all five
    // submitted payloads (sequence numbers 0..=4).
    #[test]
    fn recv_acknack_full_clear_when_no_bits_set() {
        let mut s = rs();
        for _ in 0..5 {
            s.submit(alloc::vec![0]).unwrap();
        }
        let ack = AckNackPayload {
            first_unacked_seq_num: 5,
            nack_bitmap: [0, 0],
            stream_id: 0x80,
        };
        s.recv_acknack(ack);
        assert_eq!(s.in_flight_count(), 0);
    }

    // Payloads arriving in order are drained in order and advance `expected`.
    #[test]
    fn recv_data_buffers_in_order() {
        let mut s = rs();
        s.recv_data(SerialNumber16::new(0), alloc::vec![10])
            .unwrap();
        s.recv_data(SerialNumber16::new(1), alloc::vec![11])
            .unwrap();
        let drained = s.drain_in_order();
        assert_eq!(drained.len(), 2);
        assert_eq!(drained[0].0.raw(), 0);
        assert_eq!(drained[1].0.raw(), 1);
        assert_eq!(s.expected().raw(), 2);
    }

    // Sequence number 2 arrives before 1: the first drain yields only 0, and
    // once 1 arrives the second drain yields 1 followed by the buffered 2.
    #[test]
    fn recv_data_reorders_out_of_order() {
        let mut s = rs();
        s.recv_data(SerialNumber16::new(2), alloc::vec![22])
            .unwrap();
        s.recv_data(SerialNumber16::new(0), alloc::vec![20])
            .unwrap();
        let d1 = s.drain_in_order();
        assert_eq!(d1.len(), 1);
        assert_eq!(d1[0].0.raw(), 0);
        s.recv_data(SerialNumber16::new(1), alloc::vec![21])
            .unwrap();
        let d2 = s.drain_in_order();
        assert_eq!(d2.len(), 2);
        assert_eq!(d2[0].0.raw(), 1);
        assert_eq!(d2[1].0.raw(), 2);
    }

    // A sequence number that was already drained is accepted without error
    // but not buffered again.
    #[test]
    fn recv_data_drops_duplicates() {
        let mut s = rs();
        s.recv_data(SerialNumber16::new(0), alloc::vec![1]).unwrap();
        s.drain_in_order();
        s.recv_data(SerialNumber16::new(0), alloc::vec![99])
            .unwrap();
        assert_eq!(s.out_of_order_count(), 0);
    }

    // Sequence number 0 is withheld so nothing can be drained; after the
    // buffer holds RECEIVER_BUFFER_CAP out-of-order payloads, the next
    // out-of-order payload is rejected.
    #[test]
    fn recv_data_rejects_when_buffer_full() {
        let mut s = rs();
        for i in 1..=RECEIVER_BUFFER_CAP as u16 {
            s.recv_data(SerialNumber16::new(i), alloc::vec![1]).unwrap();
        }
        let res = s.recv_data(
            SerialNumber16::new(RECEIVER_BUFFER_CAP as u16 + 1),
            alloc::vec![1],
        );
        assert!(res.is_err());
    }

    // Expected is 0, sequence numbers 1 and 3 are buffered, hint is 3:
    // bits 0 and 2 (missing 0 and 2) are set, bits 1 and 3 are clear.
    #[test]
    fn pending_acknack_marks_missing_slots() {
        let mut s = rs();
        s.recv_data(SerialNumber16::new(1), alloc::vec![1]).unwrap();
        s.recv_data(SerialNumber16::new(3), alloc::vec![3]).unwrap();
        let ack = s.pending_acknack(Some(SerialNumber16::new(3)));
        let bitmap = u16::from_le_bytes(ack.nack_bitmap);
        assert!(bitmap & (1 << 0) != 0);
        assert!(bitmap & (1 << 2) != 0);
        assert!(bitmap & (1 << 1) == 0);
        assert!(bitmap & (1 << 3) == 0);
    }

    // `reset` empties both directions and rewinds the expected sequence number.
    #[test]
    fn reset_clears_state_completely() {
        let mut s = rs();
        s.submit(alloc::vec![1, 2]).unwrap();
        s.recv_data(SerialNumber16::new(0), alloc::vec![3]).unwrap();
        s.reset();
        assert_eq!(s.in_flight_count(), 0);
        assert_eq!(s.out_of_order_count(), 0);
        assert_eq!(s.expected().raw(), 0);
    }

    // Constructing the state with a non-reliable stream id must panic.
    #[test]
    #[should_panic(expected = "reliable stream id")]
    fn constructor_panics_on_best_effort_stream() {
        let _ = ReliableStreamState::new(StreamId(1), ReliableConfig::default());
    }

    // Full round trip: three payloads are submitted, the middle one is "lost",
    // the receiver NACKs it, the sender still holds it for retransmission and
    // the receiver finally delivers everything in order.
    #[test]
    fn end_to_end_sender_receiver_with_loss_recovery() {
        let mut sender = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());
        let mut receiver = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());

        let s0 = sender.submit(alloc::vec![10]).expect("submit 0");
        let s1 = sender.submit(alloc::vec![11]).expect("submit 1");
        let s2 = sender.submit(alloc::vec![12]).expect("submit 2");
        assert_eq!(sender.in_flight_count(), 3);

        // s1 never reaches the receiver.
        receiver.recv_data(s0, alloc::vec![10]).expect("recv s0");
        receiver.recv_data(s2, alloc::vec![12]).expect("recv s2");

        // Only s0 is deliverable; s2 waits behind the gap.
        let drained = receiver.drain_in_order();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].1, alloc::vec![10]);

        // The receiver reports the gap; the sender must keep s1 available.
        let acknack = receiver.pending_acknack(Some(s2));
        sender.recv_acknack(acknack);
        assert!(
            sender.get_in_flight(s1).is_some(),
            "s1 muss retransmittable sein"
        );

        // Retransmit s1 from the sender's in-flight copy.
        let s1_payload = sender.get_in_flight(s1).expect("s1 retx").to_vec();
        receiver.recv_data(s1, s1_payload).expect("recv retx s1");

        // The gap is closed: s1 and the buffered s2 are delivered in order.
        let drained2 = receiver.drain_in_order();
        assert_eq!(drained2.len(), 2);
        assert_eq!(drained2[0].1, alloc::vec![11]);
        assert_eq!(drained2[1].1, alloc::vec![12]);
    }

    // Five payloads arrive in shuffled order and are still delivered in
    // submission order by a single drain.
    #[test]
    fn config_submessages_delivered_in_order_via_reliable_stream() {
        let mut sender = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());
        let mut receiver = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());

        let mut seqs = Vec::new();
        for i in 0..5u8 {
            let seq = sender.submit(alloc::vec![i]).expect("submit");
            seqs.push(seq);
        }

        // Arrival order on the receiver side; each payload byte equals its
        // submission index.
        let order = [2usize, 0, 4, 1, 3];
        for idx in order {
            receiver
                .recv_data(seqs[idx], alloc::vec![idx as u8])
                .expect("recv");
        }

        let drained = receiver.drain_in_order();
        assert_eq!(drained.len(), 5);
        for (i, (_, payload)) in drained.iter().enumerate() {
            assert_eq!(payload, &alloc::vec![i as u8]);
        }
    }
}