1use std::collections::BTreeSet;
44use std::time::{Duration, Instant};
45use tracing::{debug, info, warn};
46
/// Retransmission timeout, in milliseconds, that `FlowController::new` starts with.
const DEFAULT_RTO_MS: u64 = 200;
/// Per-ACK congestion-window increment in slow start; divided by `cwnd`
/// during congestion avoidance (additive increase).
const AIMD_INCREASE_STEP: f64 = 1.0;
/// Factor applied to `cwnd` to derive the new slow-start threshold on loss
/// (multiplicative decrease).
const AIMD_DECREASE_FACTOR: f64 = 0.5;
/// Value `cwnd` is reset to on loss; also the floor for `ssthresh`.
const CWND_MIN: f64 = 1.0;
51
/// Bookkeeping for a packet that has been sent but not yet acknowledged.
#[derive(Debug, Clone)]
pub struct InFlightPacket {
    /// Sequence number the packet was sent with.
    pub sequence: u64,
    /// Time of the most recent transmission of this packet; the flow
    /// controller refreshes it every time the packet is retransmitted.
    pub sent_at: Instant,
    /// Number of retransmissions performed so far (0 for a fresh packet).
    pub retransmit_count: u32,
    /// Payload, kept so the packet can be resent on timeout.
    pub data: Vec<u8>,
}
64
65impl InFlightPacket {
66 fn new(sequence: u64, data: Vec<u8>) -> Self {
67 InFlightPacket {
68 sequence,
69 sent_at: Instant::now(),
70 retransmit_count: 0,
71 data,
72 }
73 }
74
75 pub fn is_timed_out(&self, rto: Duration) -> bool {
77 self.sent_at.elapsed() > rto
78 }
79}
80
/// A packet whose retransmission timer expired and that the caller should resend.
#[derive(Debug, Clone)]
pub struct RetransmitRequest {
    /// Sequence number of the packet to resend.
    pub sequence: u64,
    /// Copy of the packet's payload.
    pub data: Vec<u8>,
    /// Retransmissions already performed *before* this one (0 on the first resend).
    pub retransmit_count: u32,
}
91
92pub struct FlowController {
95 window_size: usize,
96 cwnd: f64,
97 ssthresh: f64,
98 in_slow_start: bool,
99 in_flight: Vec<InFlightPacket>,
100 acked: BTreeSet<u64>,
101 rto: Duration,
102 srtt: Option<Duration>,
103 rttvar: Option<Duration>,
104 total_sent: u64,
105 total_acked: u64,
106 total_lost: u64,
107 total_retransmits: u64,
108}
109
110impl FlowController {
111 pub fn new(window_size: usize) -> Self {
113 debug!(window_size, "FlowController created");
114 FlowController {
115 window_size,
116 cwnd: 1.0,
117 ssthresh: window_size as f64 / 2.0,
118 in_slow_start: true,
119 in_flight: Vec::new(),
120 acked: BTreeSet::new(),
121 rto: Duration::from_millis(DEFAULT_RTO_MS),
122 srtt: None,
123 rttvar: None,
124 total_sent: 0,
125 total_acked: 0,
126 total_lost: 0,
127 total_retransmits: 0,
128 }
129 }
130
131 pub fn with_rto(window_size: usize, rto_ms: u64) -> Self {
133 let mut fc = Self::new(window_size);
134 fc.rto = Duration::from_millis(rto_ms);
135 fc
136 }
137
138 pub fn can_send(&self) -> bool {
142 self.in_flight.len() < self.effective_window()
143 }
144
145 pub fn available_slots(&self) -> usize {
147 self.effective_window().saturating_sub(self.in_flight.len())
148 }
149
150 pub fn window_size(&self) -> usize {
152 self.window_size
153 }
154
155 pub fn cwnd(&self) -> f64 {
157 self.cwnd
158 }
159
160 pub fn effective_window(&self) -> usize {
162 (self.cwnd as usize).min(self.window_size).max(1)
163 }
164
165 pub fn in_slow_start(&self) -> bool {
167 self.in_slow_start
168 }
169
170 pub fn set_window_size(&mut self, size: usize) {
172 debug!(old = self.window_size, new = size, "Window size updated");
173 self.window_size = size;
174 }
175
176 pub fn in_flight_count(&self) -> usize {
178 self.in_flight.len()
179 }
180
181 pub fn oldest_unacked_sequence(&self) -> Option<u64> {
183 self.in_flight.first().map(|p| p.sequence)
184 }
185
186 pub fn on_send(&mut self, sequence: u64, data: Vec<u8>) -> bool {
192 if !self.can_send() {
193 warn!(
194 sequence,
195 in_flight = self.in_flight.len(),
196 cwnd = self.cwnd,
197 "on_send() called but window is full"
198 );
199 return false;
200 }
201 self.in_flight.push(InFlightPacket::new(sequence, data));
202 self.total_sent += 1;
203 debug!(
204 sequence,
205 in_flight = self.in_flight.len(),
206 cwnd = self.cwnd,
207 effective_window = self.effective_window(),
208 "Packet sent"
209 );
210 true
211 }
212
213 pub fn on_ack(&mut self, sequence: u64) -> bool {
218 if let Some(pos) = self.in_flight.iter().position(|p| p.sequence == sequence) {
219 let packet = self.in_flight.remove(pos);
220 let rtt = packet.sent_at.elapsed();
221 self.update_rtt(rtt);
222 self.acked.insert(sequence);
223 self.total_acked += 1;
224 self.on_ack_cwnd();
225 debug!(
226 sequence,
227 rtt_ms = rtt.as_millis(),
228 in_flight = self.in_flight.len(),
229 cwnd = self.cwnd,
230 in_slow_start = self.in_slow_start,
231 "Packet acked"
232 );
233 true
234 } else {
235 warn!(sequence, "on_ack() for unknown or duplicate sequence");
236 false
237 }
238 }
239
240 pub fn timed_out_packets(&mut self) -> Vec<RetransmitRequest> {
245 let rto = self.rto;
246 let mut requests = Vec::new();
247 let mut had_loss = false;
248
249 for packet in self.in_flight.iter_mut() {
250 if packet.is_timed_out(rto) {
251 warn!(
252 sequence = packet.sequence,
253 retransmit_count = packet.retransmit_count,
254 rto_ms = rto.as_millis(),
255 "Packet timed out — queuing retransmission"
256 );
257 requests.push(RetransmitRequest {
258 sequence: packet.sequence,
259 data: packet.data.clone(),
260 retransmit_count: packet.retransmit_count,
261 });
262 packet.retransmit_count += 1;
263 packet.sent_at = Instant::now();
264 self.total_lost += 1;
265 self.total_retransmits += 1;
266 had_loss = true;
267 }
268 }
269
270 if had_loss {
271 self.on_loss_cwnd();
272 }
273
274 requests
275 }
276
277 fn on_ack_cwnd(&mut self) {
280 if self.in_slow_start {
281 self.cwnd += AIMD_INCREASE_STEP;
282 if self.cwnd >= self.ssthresh {
283 self.in_slow_start = false;
284 info!(cwnd = self.cwnd, ssthresh = self.ssthresh, "Exiting slow start");
285 }
286 } else {
287 self.cwnd += AIMD_INCREASE_STEP / self.cwnd;
288 }
289 self.cwnd = self.cwnd.min(self.window_size as f64);
290 debug!(cwnd = self.cwnd, "AIMD: cwnd increased");
291 }
292
293 fn on_loss_cwnd(&mut self) {
294 self.ssthresh = (self.cwnd * AIMD_DECREASE_FACTOR).max(CWND_MIN);
295 self.cwnd = CWND_MIN;
296 self.in_slow_start = true;
297 self.rto = (self.rto * 2).min(Duration::from_secs(60));
298 warn!(
299 cwnd = self.cwnd,
300 ssthresh = self.ssthresh,
301 rto_ms = self.rto.as_millis(),
302 "AIMD: multiplicative decrease on loss"
303 );
304 }
305
306 fn update_rtt(&mut self, rtt: Duration) {
309 match (self.srtt, self.rttvar) {
310 (None, None) => {
311 self.srtt = Some(rtt);
312 self.rttvar = Some(rtt / 2);
313 }
314 (Some(srtt), Some(rttvar)) => {
315 let rtt_ns = rtt.as_nanos() as i128;
316 let srtt_ns = srtt.as_nanos() as i128;
317 let rttvar_ns = rttvar.as_nanos() as i128;
318 let new_rttvar = (rttvar_ns * 3 / 4 + (srtt_ns - rtt_ns).abs() / 4).max(0) as u64;
319 let new_srtt = (srtt_ns * 7 / 8 + rtt_ns / 8).max(1) as u64;
320 self.rttvar = Some(Duration::from_nanos(new_rttvar));
321 self.srtt = Some(Duration::from_nanos(new_srtt));
322 let rto_ns = new_srtt + (new_rttvar * 4).max(1_000_000);
323 self.rto = Duration::from_nanos(rto_ns)
324 .max(Duration::from_millis(50))
325 .min(Duration::from_secs(60));
326 }
327 _ => {}
328 }
329 }
330
331 pub fn srtt(&self) -> Option<Duration> {
335 self.srtt
336 }
337
338 pub fn rttvar(&self) -> Option<Duration> {
340 self.rttvar
341 }
342
343 pub fn rto(&self) -> Duration {
345 self.rto
346 }
347
348 pub fn total_sent(&self) -> u64 {
350 self.total_sent
351 }
352
353 pub fn total_acked(&self) -> u64 {
355 self.total_acked
356 }
357
358 pub fn total_lost(&self) -> u64 {
360 self.total_lost
361 }
362
363 pub fn total_retransmits(&self) -> u64 {
365 self.total_retransmits
366 }
367
368 pub fn loss_rate(&self) -> f64 {
370 if self.total_sent == 0 { return 0.0; }
371 self.total_lost as f64 / self.total_sent as f64
372 }
373
374 pub fn is_acked(&self, sequence: u64) -> bool {
376 self.acked.contains(&sequence)
377 }
378
379 pub fn reset(&mut self) {
381 debug!("FlowController reset");
382 let window_size = self.window_size;
383 *self = Self::new(window_size);
384 }
385}
386
387impl Default for FlowController {
388 fn default() -> Self {
389 Self::new(64)
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let fc = FlowController::new(4);
        assert_eq!(fc.window_size(), 4);
        assert_eq!(fc.in_flight_count(), 0);
        assert!(fc.can_send());
        assert_eq!(fc.available_slots(), 1);
        assert!(fc.in_slow_start());
    }

    #[test]
    fn test_window_full() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.can_send());
        assert_eq!(fc.in_flight_count(), 1);
    }

    #[test]
    fn test_ack_opens_window_and_grows_cwnd() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.can_send());
        let cwnd_before = fc.cwnd();
        fc.on_ack(0);
        assert!(fc.cwnd() > cwnd_before);
        assert!(fc.can_send());
    }

    #[test]
    fn test_ack_unknown_sequence() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(!fc.on_ack(99));
        assert_eq!(fc.in_flight_count(), 1);
    }

    #[test]
    fn test_duplicate_ack_rejected() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(fc.on_ack(0));
        assert!(!fc.on_ack(0));
        assert_eq!(fc.total_acked(), 1);
        assert!(fc.is_acked(0));
    }

    #[test]
    fn test_is_acked() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(!fc.is_acked(0));
        fc.on_ack(0);
        assert!(fc.is_acked(0));
    }

    #[test]
    fn test_stats() {
        let mut fc = FlowController::new(10);
        for i in 0..5 {
            if fc.can_send() {
                fc.on_send(i, vec![0]);
                fc.on_ack(i);
            }
        }
        assert_eq!(fc.total_acked(), 5);
    }

    #[test]
    fn test_loss_rate_zero() {
        let fc = FlowController::new(4);
        assert_eq!(fc.loss_rate(), 0.0);
    }

    #[test]
    fn test_loss_rate_after_timeout() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        fc.timed_out_packets();
        assert_eq!(fc.loss_rate(), 1.0);
    }

    #[test]
    fn test_set_window_size() {
        let mut fc = FlowController::new(4);
        fc.set_window_size(8);
        assert_eq!(fc.window_size(), 8);
    }

    #[test]
    fn test_reset() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        fc.on_ack(0);
        fc.reset();
        assert_eq!(fc.in_flight_count(), 0);
        assert_eq!(fc.total_sent(), 0);
        assert_eq!(fc.total_acked(), 0);
        assert!(fc.srtt().is_none());
        assert!(fc.in_slow_start());
        assert_eq!(fc.cwnd(), 1.0);
    }

    #[test]
    fn test_timed_out_packets_returns_retransmit_requests() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, b"hello".to_vec());
        std::thread::sleep(Duration::from_millis(5));
        let requests = fc.timed_out_packets();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].sequence, 0);
        assert_eq!(requests[0].data, b"hello");
        assert_eq!(requests[0].retransmit_count, 0);
        assert_eq!(fc.total_lost(), 1);
        assert_eq!(fc.total_retransmits(), 1);
    }

    #[test]
    fn test_retransmit_count_increments_per_timeout() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, b"x".to_vec());
        std::thread::sleep(Duration::from_millis(5));
        let first = fc.timed_out_packets();
        assert_eq!(first.len(), 1);
        assert_eq!(first[0].retransmit_count, 0);
        // RTO has doubled to 2ms; wait comfortably past it.
        std::thread::sleep(Duration::from_millis(10));
        let second = fc.timed_out_packets();
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].retransmit_count, 1);
    }

    #[test]
    fn test_aimd_multiplicative_decrease_on_loss() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        let requests = fc.timed_out_packets();
        assert!(!requests.is_empty());
        assert_eq!(fc.cwnd(), 1.0);
        assert!(fc.in_slow_start());
        assert_eq!(fc.total_lost(), 1);
    }

    #[test]
    fn test_slow_start_exits_at_ssthresh() {
        let mut fc = FlowController::new(64);
        let ssthresh = fc.ssthresh;
        // Bounded loop that always advances, so a window that wrongly reports
        // "full" makes the test fail instead of spinning forever.
        for i in 0..1000u64 {
            if !fc.in_slow_start() {
                break;
            }
            assert!(fc.can_send(), "window unexpectedly full at iteration {}", i);
            fc.on_send(i, vec![0]);
            fc.on_ack(i);
        }
        assert!(!fc.in_slow_start());
        assert!(fc.cwnd() >= ssthresh);
    }

    #[test]
    fn test_srtt_updated_on_ack() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(fc.srtt().is_none());
        fc.on_ack(0);
        assert!(fc.srtt().is_some());
        assert!(fc.rttvar().is_some());
    }

    #[test]
    fn test_default() {
        let fc = FlowController::default();
        assert_eq!(fc.window_size(), 64);
    }

    #[test]
    fn test_on_send_full_window_returns_false() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.on_send(1, vec![0]));
    }

    #[test]
    fn test_multiple_acks_grow_cwnd() {
        let mut fc = FlowController::new(64);
        let initial_cwnd = fc.cwnd();
        for i in 0..10u64 {
            if fc.can_send() {
                fc.on_send(i, vec![0]);
                fc.on_ack(i);
            }
        }
        assert!(fc.cwnd() > initial_cwnd);
        assert_eq!(fc.total_acked(), 10);
    }

    #[test]
    fn test_available_slots_grow_after_ack() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        fc.on_ack(0);
        assert_eq!(fc.cwnd(), 2.0);
        assert_eq!(fc.effective_window(), 2);
        assert_eq!(fc.available_slots(), 2);
    }

    #[test]
    fn test_cwnd_capped_at_window_size() {
        let mut fc = FlowController::new(2);
        for i in 0..20u64 {
            assert!(fc.on_send(i, vec![0]));
            assert!(fc.on_ack(i));
        }
        assert_eq!(fc.cwnd(), 2.0);
        assert_eq!(fc.effective_window(), 2);
    }

    #[test]
    fn test_oldest_unacked_sequence() {
        let mut fc = FlowController::new(4);
        assert!(fc.oldest_unacked_sequence().is_none());
        fc.on_send(5, vec![0]);
        assert_eq!(fc.oldest_unacked_sequence(), Some(5));
    }

    #[test]
    fn test_effective_window_bounded_by_cwnd_and_max() {
        let fc = FlowController::new(4);
        assert_eq!(fc.effective_window(), 1);
    }

    #[test]
    fn test_rto_doubles_on_loss() {
        let mut fc = FlowController::with_rto(4, 1);
        let rto_before = fc.rto();
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        fc.timed_out_packets();
        assert!(fc.rto() > rto_before);
    }

    #[test]
    fn test_total_retransmits() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        fc.timed_out_packets();
        assert_eq!(fc.total_retransmits(), 1);
        std::thread::sleep(Duration::from_millis(10));
        fc.timed_out_packets();
        assert_eq!(fc.total_retransmits(), 2);
    }
}