1use std::collections::VecDeque;
17use std::time::{Duration, Instant};
18
19use bytes::{BufMut, Bytes, BytesMut};
20
21use crate::protocol::OpCode;
22use crate::rc4::Rc4KeyState;
23
24use super::true_incoming_sequence;
25
/// Bytes occupied by the big-endian `u16` sequence number at the start of every packet.
const SEQUENCE_SIZE: usize = std::mem::size_of::<u16>();
/// Bytes occupied by the big-endian `u32` total-length header carried by the first
/// ("master") fragment of a fragmented message.
const FRAGMENT_LENGTH_SIZE: usize = std::mem::size_of::<u32>();

/// Multiplier applied to RTTVAR when computing the retransmission timeout
/// (the `K` of RFC 6298).
const RTT_K: u32 = 4;
/// Lower bound for the variance term added to SRTT when computing the RTO
/// (the clock-granularity `G` of RFC 6298).
const RTO_GRANULARITY: Duration = Duration::from_millis(100);
/// Smallest retransmission timeout the adaptive estimator may produce.
const RTO_MIN: Duration = Duration::from_millis(200);
/// Largest retransmission timeout; also caps exponential backoff after a timeout.
const RTO_MAX: Duration = Duration::from_millis(8_000);
43
/// Running counters kept by a [`ReliableDataOutputChannel`].
#[derive(Clone, Debug, Default)]
pub struct DataOutputStats {
    /// Packets emitted to the outgoing buffer, first transmissions and retransmissions alike.
    pub total_sent: u64,
    /// Packets emitted again after having already been emitted at least once.
    pub total_resent: u64,
    /// Acknowledgement notifications received (single and cumulative).
    pub incoming_acknowledge_count: u64,
    /// Queued packets actually removed because an acknowledgement covered them.
    pub actual_acknowledge_count: u64,
}
56
/// Tunables for a [`ReliableDataOutputChannel`].
#[derive(Clone, Debug)]
pub struct OutputConfig {
    /// Largest packet the channel builds, *including* the 2-byte sequence header.
    pub max_data_length: usize,
    /// Window size: how many packets from the front of the queue may be in flight.
    /// Also used as the window when expanding 16-bit acknowledgement sequences.
    pub max_queued_outgoing: usize,
    /// Initial retransmission timeout; the channel adapts it afterwards from
    /// measured round trips and timeout backoff.
    pub ack_wait: Duration,
}

impl Default for OutputConfig {
    fn default() -> Self {
        let ack_wait = Duration::from_millis(500);
        OutputConfig {
            max_data_length: 508,
            max_queued_outgoing: 196,
            ack_wait,
        }
    }
}
83
/// A packet emitted by [`ReliableDataOutputChannel::run_tick`], ready to be put on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutgoingReliable {
    /// `OpCode::ReliableData` for a message that fit in one packet,
    /// `OpCode::ReliableDataFragment` for every packet of a split message.
    pub op_code: OpCode,
    /// Encoded packet: a big-endian `u16` sequence number, then (on the first
    /// fragment of a split message only) a big-endian `u32` total length, then data.
    pub payload: Bytes,
}
95
/// A queued packet awaiting acknowledgement, together with its transmission bookkeeping.
#[derive(Debug)]
struct StashedOutputPacket {
    /// Whether this packet belongs to a fragmented message (selects the op code).
    is_fragment: bool,
    /// Fully encoded packet bytes (sequence header included); resent verbatim.
    data: Bytes,
    /// Set once the packet has been dispatched at least once.
    sent: bool,
    /// Time of the most recent dispatch; `None` until first dispatched.
    sent_at: Option<Instant>,
    /// Set once the packet has been dispatched more than once. Acknowledgements of
    /// such packets yield no RTT sample (Karn's algorithm).
    resent: bool,
}
108
/// Sender side of a reliable data channel.
///
/// Payloads are optionally encrypted, split into sequence-numbered packets,
/// dispatched within a fixed-size window, retransmitted go-back-N style when the
/// oldest in-flight packet times out, and dropped from the queue once
/// acknowledged. The retransmission timeout adapts to measured round trips.
#[derive(Debug)]
pub struct ReliableDataOutputChannel {
    config: OutputConfig,
    /// Applied to each payload before fragmenting, when present.
    cipher: Option<Rc4KeyState>,

    /// Unacknowledged packets keyed by 64-bit sequence, in ascending sequence order.
    dispatch_queue: VecDeque<(i64, StashedOutputPacket)>,

    /// Sequence number the next stashed packet will receive.
    total_sequence: i64,
    /// Highest sequence named by any acknowledgement so far; reference point for
    /// expanding 16-bit wire sequences.
    max_client_sequence: i64,
    /// Number of packets at the front of `dispatch_queue` already dispatched in the
    /// current round; `run_tick` continues dispatching from here.
    current_dispatch_index: usize,

    /// Smoothed round-trip time; `None` until the first valid sample.
    srtt: Option<Duration>,
    /// Round-trip time variation estimate.
    rttvar: Duration,
    /// Current retransmission timeout applied to the oldest in-flight packet.
    rto: Duration,

    /// Packets emitted by `run_tick`, drained by `take_outgoing`.
    outgoing: Vec<OutgoingReliable>,
    stats: DataOutputStats,
}
134
135impl ReliableDataOutputChannel {
136 pub fn new(config: OutputConfig, cipher: Option<Rc4KeyState>, _now: Instant) -> Self {
140 let initial_rto = config.ack_wait;
141 Self {
142 config,
143 cipher,
144 dispatch_queue: VecDeque::new(),
145 total_sequence: 0,
146 max_client_sequence: 0,
147 current_dispatch_index: 0,
148 srtt: None,
149 rttvar: Duration::ZERO,
150 rto: initial_rto,
151 outgoing: Vec::new(),
152 stats: DataOutputStats::default(),
153 }
154 }
155
156 pub fn stats(&self) -> &DataOutputStats {
158 &self.stats
159 }
160
161 pub fn take_outgoing(&mut self) -> Vec<OutgoingReliable> {
163 std::mem::take(&mut self.outgoing)
164 }
165
166 pub fn queued_len(&self) -> usize {
168 self.dispatch_queue.len()
169 }
170
171 pub fn set_max_data_length(&mut self, max_data_length: usize) {
174 self.config.max_data_length = max_data_length;
175 }
176
177 fn max_chunk(&self) -> usize {
178 self.config.max_data_length - SEQUENCE_SIZE
179 }
180
181 pub fn enqueue_data(&mut self, data: &[u8]) {
184 if data.is_empty() {
185 return;
186 }
187
188 let mut remaining: Bytes = match &mut self.cipher {
189 Some(_) => self.encrypt(data),
190 None => Bytes::copy_from_slice(data),
191 };
192
193 let is_fragment = remaining.len() > self.max_chunk();
194 self.stash_fragment(&mut remaining, true, is_fragment);
195 while !remaining.is_empty() {
196 self.stash_fragment(&mut remaining, false, true);
197 }
198 }
199
200 pub fn run_tick(&mut self, now: Instant) {
205 let timed_out = match self.dispatch_queue.front() {
209 Some((_, front)) if front.sent => front
210 .sent_at
211 .is_some_and(|sent_at| now.duration_since(sent_at) > self.rto),
212 _ => false,
213 };
214 if timed_out {
215 self.current_dispatch_index = 0;
216 self.rto = (self.rto * 2).min(RTO_MAX);
218 }
219
220 let max_index = self
221 .dispatch_queue
222 .len()
223 .min(self.config.max_queued_outgoing);
224
225 while self.current_dispatch_index < max_index {
226 let (_, packet) = &mut self.dispatch_queue[self.current_dispatch_index];
227 let op_code = if packet.is_fragment {
228 OpCode::ReliableDataFragment
229 } else {
230 OpCode::ReliableData
231 };
232
233 self.stats.total_sent += 1;
234 if packet.sent {
235 self.stats.total_resent += 1;
236 packet.resent = true;
238 }
239 packet.sent = true;
240 packet.sent_at = Some(now);
241
242 let payload = packet.data.clone();
243 self.outgoing.push(OutgoingReliable { op_code, payload });
244 self.current_dispatch_index += 1;
245 }
246 }
247
248 fn update_rto(&mut self, sample: Duration) {
251 match self.srtt {
252 None => {
253 self.srtt = Some(sample);
254 self.rttvar = sample / 2;
255 }
256 Some(srtt) => {
257 let diff = srtt.abs_diff(sample);
258 self.rttvar = (self.rttvar * 3 + diff) / 4;
260 self.srtt = Some((srtt * 7 + sample) / 8);
262 }
263 }
264 let srtt = self.srtt.unwrap_or(sample);
265 let rto = srtt + std::cmp::max(RTO_GRANULARITY, self.rttvar * RTT_K);
266 self.rto = rto.clamp(RTO_MIN, RTO_MAX);
267 }
268
269 pub fn notify_of_acknowledge(&mut self, sequence: u16, now: Instant) {
271 let seq = self.true_incoming(sequence);
272 self.stats.incoming_acknowledge_count += 1;
273
274 if let Some(pos) = self.dispatch_queue.iter().position(|(s, _)| *s == seq) {
275 let (_, pkt) = &self.dispatch_queue[pos];
276 let sample = (pkt.sent && !pkt.resent)
277 .then(|| pkt.sent_at.map(|sent_at| now.duration_since(sent_at)))
278 .flatten();
279 self.dispatch_queue.remove(pos);
280 self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
281 self.stats.actual_acknowledge_count += 1;
282 if let Some(sample) = sample {
283 self.update_rto(sample);
284 }
285 }
286
287 if seq > self.max_client_sequence {
288 self.max_client_sequence = seq;
289 }
290 }
291
292 pub fn notify_of_acknowledge_all(&mut self, sequence: u16, now: Instant) {
295 let seq = self.true_incoming(sequence);
296 self.stats.incoming_acknowledge_count += 1;
297
298 let mut sample: Option<Duration> = None;
299 loop {
300 let (pop, this_sample) = match self.dispatch_queue.front() {
301 Some((s, pkt)) if *s <= seq => {
302 let smp = (pkt.sent && !pkt.resent)
303 .then(|| pkt.sent_at.map(|sent_at| now.duration_since(sent_at)))
304 .flatten();
305 (true, smp)
306 }
307 _ => (false, None),
308 };
309 if !pop {
310 break;
311 }
312 if this_sample.is_some() {
314 sample = this_sample;
315 }
316 self.dispatch_queue.pop_front();
317 self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
318 self.stats.actual_acknowledge_count += 1;
319 }
320
321 if let Some(sample) = sample {
322 self.update_rto(sample);
323 }
324
325 if seq > self.max_client_sequence {
326 self.max_client_sequence = seq;
327 }
328 }
329
330 fn stash_fragment(&mut self, data: &mut Bytes, is_master: bool, is_fragment: bool) {
331 let mut amount = data.len().min(self.max_chunk());
332
333 let mut buf = BytesMut::with_capacity(SEQUENCE_SIZE + FRAGMENT_LENGTH_SIZE + amount);
334 buf.put_u16(self.total_sequence as u16);
335
336 if is_master && is_fragment {
337 buf.put_u32(data.len() as u32);
338 amount -= FRAGMENT_LENGTH_SIZE;
339 }
340
341 buf.extend_from_slice(&data[..amount]);
342
343 self.dispatch_queue.push_back((
344 self.total_sequence,
345 StashedOutputPacket {
346 is_fragment,
347 data: buf.freeze(),
348 sent: false,
349 sent_at: None,
350 resent: false,
351 },
352 ));
353
354 self.total_sequence += 1;
355 *data = data.slice(amount..);
356 }
357
358 fn encrypt(&mut self, data: &[u8]) -> Bytes {
362 let cipher = self
363 .cipher
364 .as_mut()
365 .expect("encrypt called without a cipher");
366
367 let mut buf = BytesMut::with_capacity(data.len() + 1);
368 buf.put_u8(0);
369 buf.extend_from_slice(data);
370 cipher.transform_in_place(&mut buf[1..]);
371
372 let frozen = buf.freeze();
373 if frozen[1] == 0 {
374 frozen
375 } else {
376 frozen.slice(1..)
377 }
378 }
379
380 fn true_incoming(&self, packet_sequence: u16) -> i64 {
381 true_incoming_sequence(
382 packet_sequence,
383 self.max_client_sequence,
384 self.config.max_queued_outgoing as i64,
385 )
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    // Payload bytes per packet after the 2-byte sequence header (see `new_channel`).
    const MAX_DATA_LENGTH: usize = 506;
    // Window size (`max_queued_outgoing`) of every channel built by these tests.
    const FRAGMENT_WINDOW_SIZE: usize = 8;

    /// Manually advanced clock so timing-dependent behaviour is deterministic.
    struct Clock {
        now: Instant,
    }

    impl Clock {
        fn new() -> Self {
            Self {
                now: Instant::now(),
            }
        }
        /// Moves the clock forward by `by` and returns the new time.
        fn advance(&mut self, by: Duration) -> Instant {
            self.now += by;
            self.now
        }
    }

    /// Unencrypted channel whose per-packet payload capacity is exactly
    /// `MAX_DATA_LENGTH` and whose initial RTO is 500ms.
    fn new_channel(clock: &Clock) -> ReliableDataOutputChannel {
        let config = OutputConfig {
            max_data_length: MAX_DATA_LENGTH + SEQUENCE_SIZE,
            max_queued_outgoing: FRAGMENT_WINDOW_SIZE,
            ack_wait: Duration::from_millis(500),
        };
        ReliableDataOutputChannel::new(config, None, clock.now)
    }

    /// Deterministic pseudo-random bytes from a linear congruential generator
    /// seeded with `size`.
    fn generate_packet(size: usize) -> Vec<u8> {
        let mut state: u32 = 0x1234_5678 ^ size as u32;
        (0..size)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                (state >> 24) as u8
            })
            .collect()
    }

    /// Asserts that the packet payloads, with headers stripped and concatenated in
    /// order, equal `buffer` exactly. When `expect_master_fragment` is set, the first
    /// packet is expected to also carry the 4-byte total-length header.
    fn assert_packets_equal_buffer(
        packets: &[OutgoingReliable],
        buffer: &[u8],
        mut expect_master_fragment: bool,
    ) {
        let mut position = 0;
        for packet in packets {
            let data_offset = SEQUENCE_SIZE
                + if expect_master_fragment {
                    FRAGMENT_LENGTH_SIZE
                } else {
                    0
                };
            expect_master_fragment = false;

            let data = &packet.payload[data_offset..];
            assert!(
                position + data.len() <= buffer.len(),
                "received more data than expected"
            );
            assert_eq!(&buffer[position..position + data.len()], data);
            position += data.len();
        }
        assert_eq!(position, buffer.len(), "did not receive the whole buffer");
    }

    /// With no acknowledgements at all, waiting past the initial 500ms RTO resends
    /// the whole message from the first fragment.
    #[test]
    fn repeats_data_on_ack_failure() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        // The master fragment loses 4 bytes to the length header, so this length
        // produces exactly `fragment_count` full packets.
        let fragment_count = 4;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);

        ch.enqueue_data(&packet);
        ch.run_tick(clock.advance(Duration::from_millis(1)));
        assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);

        ch.run_tick(clock.advance(Duration::from_millis(600)));
        assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);
    }

    /// After the first two fragments are cumulatively acknowledged, a timeout
    /// resends only the remaining fragments.
    #[test]
    fn repeats_data_from_arbitrary_position_on_ack_delay() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let fragment_count = 4;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);

        ch.enqueue_data(&packet);
        ch.run_tick(clock.advance(Duration::from_millis(1)));
        assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);

        ch.notify_of_acknowledge_all(1, clock.advance(Duration::from_millis(1)));

        ch.run_tick(clock.advance(Duration::from_millis(600)));
        // Bytes carried by fragments 0 (master) and 1, which were acknowledged.
        let expected_consumed = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH;
        assert_packets_equal_buffer(&ch.take_outgoing(), &packet[expected_consumed..], false);
    }

    /// With more fragments than fit in the window, only one window is sent at first;
    /// after a partial cumulative ack and a timeout, a full window is resent starting
    /// at the first unacknowledged fragment.
    #[test]
    fn repeats_full_window_from_arbitrary_position_on_ack_delay() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let fragment_count = FRAGMENT_WINDOW_SIZE * 2;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);

        ch.enqueue_data(&packet);
        ch.run_tick(clock.advance(Duration::from_millis(1)));

        // Exactly one window: the master fragment plus `FRAGMENT_WINDOW_SIZE - 1` more.
        let expected_receive_length =
            MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (FRAGMENT_WINDOW_SIZE - 1);
        assert_packets_equal_buffer(
            &ch.take_outgoing(),
            &packet[..expected_receive_length],
            true,
        );

        ch.notify_of_acknowledge_all(
            (FRAGMENT_WINDOW_SIZE - 2) as u16,
            clock.advance(Duration::from_millis(1)),
        );
        ch.run_tick(clock.advance(Duration::from_millis(600)));

        // Sequences 0..=FRAGMENT_WINDOW_SIZE-2 were acknowledged; the resend covers
        // the next full window of non-master fragments.
        let expected_consumed = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (FRAGMENT_WINDOW_SIZE - 2);
        let expected_repeat_length = MAX_DATA_LENGTH * FRAGMENT_WINDOW_SIZE;
        assert_packets_equal_buffer(
            &ch.take_outgoing(),
            &packet[expected_consumed..expected_consumed + expected_repeat_length],
            false,
        );
    }

    /// A payload that fits in one packet is sent as plain `ReliableData` with no
    /// length header.
    #[test]
    fn single_small_packet_is_not_fragmented() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let data = generate_packet(32);
        ch.enqueue_data(&data);
        ch.run_tick(clock.advance(Duration::from_millis(1)));

        let outgoing = ch.take_outgoing();
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].op_code, OpCode::ReliableData);
        assert_eq!(&outgoing[0].payload[SEQUENCE_SIZE..], &data[..]);
    }

    /// A single (non-cumulative) ack removes exactly the named packet.
    #[test]
    fn single_ack_removes_specific_packet() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * 3;
        let packet = generate_packet(packet_length);
        ch.enqueue_data(&packet);
        assert_eq!(ch.queued_len(), 4);

        ch.run_tick(clock.advance(Duration::from_millis(1)));
        let _ = ch.take_outgoing();

        ch.notify_of_acknowledge(2, clock.advance(Duration::from_millis(1)));
        assert_eq!(ch.queued_len(), 3);
        assert_eq!(ch.stats().actual_acknowledge_count, 1);
    }

    /// Repeated ticks without acknowledgements (and before any timeout) must not
    /// dispatch beyond the first window.
    #[test]
    fn window_does_not_grow_across_ticks_without_ack() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let fragment_count = FRAGMENT_WINDOW_SIZE * 4;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);
        ch.enqueue_data(&packet);

        ch.run_tick(clock.advance(Duration::from_millis(1)));
        let mut in_flight = ch.take_outgoing().len();
        assert_eq!(
            in_flight, FRAGMENT_WINDOW_SIZE,
            "first tick should send exactly one window"
        );

        // 5 x 10ms stays well below the 500ms initial RTO, so nothing may be sent.
        for _ in 0..5 {
            ch.run_tick(clock.advance(Duration::from_millis(10)));
            in_flight += ch.take_outgoing().len();
            assert!(
                in_flight <= FRAGMENT_WINDOW_SIZE,
                "in-flight unacked packets ({in_flight}) exceeded the window ({FRAGMENT_WINDOW_SIZE})",
            );
        }
    }

    /// The first window is acknowledged 500ms after sending. That sample makes the
    /// RTO 500ms + 4 * 250ms = 1500ms, so waiting 600ms after sending the next batch
    /// must not trigger a retransmission.
    #[test]
    fn adaptive_rto_suppresses_resend_after_learning_high_rtt() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);
        let fragment_count = FRAGMENT_WINDOW_SIZE + 4;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);
        ch.enqueue_data(&packet);
        ch.run_tick(clock.advance(Duration::from_millis(1)));
        let _ = ch.take_outgoing();

        // Cumulative ack of the whole first window, one 500ms round trip later.
        ch.notify_of_acknowledge_all(
            (FRAGMENT_WINDOW_SIZE - 1) as u16,
            clock.advance(Duration::from_millis(500)),
        );

        // Send the remaining fragments, then wait longer than the initial RTO but
        // shorter than the learned one.
        ch.run_tick(clock.advance(Duration::from_millis(1)));
        let _ = ch.take_outgoing();
        ch.run_tick(clock.advance(Duration::from_millis(600)));
        let resent = ch.take_outgoing();

        assert!(
            resent.is_empty(),
            "adaptive RTO must not resend within the learned RTT, but resent {} packets",
            resent.len()
        );
        assert_eq!(
            ch.stats().total_resent,
            0,
            "no packet should have been retransmitted after the RTO adapted to the RTT"
        );
    }

    /// End-to-end simulation over a link with a fixed 250ms one-way delay in both
    /// directions. The simulated client cumulatively acks the highest contiguous
    /// sequence it has received. The channel must drain everything without
    /// exceeding the window and with at most 25% extra datagrams on the wire.
    #[test]
    fn adaptive_rto_bounds_inflight_at_high_rtt() {
        let mut clock = Clock::new();
        let mut ch = new_channel(&clock);

        let one_way = Duration::from_millis(250);
        let tick = Duration::from_millis(5);

        let fragment_count = 30;
        let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
        let packet = generate_packet(packet_length);
        ch.enqueue_data(&packet);
        let unique = ch.queued_len();

        // In-transit datagrams: (arrival time, sequence) in each direction.
        let mut to_client: Vec<(Instant, u16)> = Vec::new();
        let mut to_server: Vec<(Instant, u16)> = Vec::new();
        let mut received = vec![false; unique];

        let mut total_on_wire = 0usize;
        let mut highest_sent: i64 = -1;
        let mut last_ack: i64 = -1;
        let mut max_in_flight: i64 = 0;

        for _ in 0..800 {
            let now = clock.advance(tick);

            // Deliver acknowledgements that have reached the server.
            to_server.retain(|&(at, ack)| {
                if at <= now {
                    ch.notify_of_acknowledge_all(ack, now);
                    last_ack = last_ack.max(ack as i64);
                    false
                } else {
                    true
                }
            });

            // Deliver data that has reached the client.
            let mut delivered_any = false;
            to_client.retain(|&(at, seq)| {
                if at <= now {
                    received[seq as usize] = true;
                    delivered_any = true;
                    false
                } else {
                    true
                }
            });
            // On any delivery, the client acks its highest contiguous sequence.
            if delivered_any {
                let mut hw: i64 = -1;
                for (seq, got) in received.iter().enumerate() {
                    if *got {
                        hw = seq as i64;
                    } else {
                        break;
                    }
                }
                if hw >= 0 {
                    to_server.push((now + one_way, hw as u16));
                }
            }

            // Let the channel send, and put everything it emits on the wire.
            ch.run_tick(now);
            for out in ch.take_outgoing() {
                let seq = u16::from_be_bytes([out.payload[0], out.payload[1]]);
                total_on_wire += 1;
                highest_sent = highest_sent.max(seq as i64);
                to_client.push((now + one_way, seq));
            }

            max_in_flight = max_in_flight.max(highest_sent - last_ack);

            if last_ack >= 0 && last_ack as usize + 1 == unique {
                break;
            }
        }

        assert!(
            last_ack >= 0 && last_ack as usize + 1 == unique,
            "channel did not drain all {unique} packets (acked through {last_ack})"
        );
        assert!(
            max_in_flight <= FRAGMENT_WINDOW_SIZE as i64 + 2,
            "in-flight ({max_in_flight}) far exceeded the window ({FRAGMENT_WINDOW_SIZE}) -> resend storm",
        );
        assert!(
            total_on_wire <= unique + unique / 4,
            "sent {total_on_wire} datagrams for {unique} unique packets (>1.25x = resend storm)",
        );
    }
}