// tokio_udt/rate_control.rs

1use crate::flow::UdtFlow;
2use crate::seq_number::SeqNumber;
3use crate::socket::SYN_INTERVAL;
4use rand::Rng;
5use tokio::time::{Duration, Instant};
6
/// Rate and congestion-window control state for one UDT connection.
///
/// Holds the inter-packet sending period and the congestion window, which are
/// adjusted by the controller's ACK, loss and timeout handlers.
#[derive(Debug)]
pub struct RateControl {
    /// Delay between two consecutive data packets; the sending rate is its inverse.
    pkt_send_period: Duration,
    /// Congestion window, in packets.
    congestion_window_size: f64,
    /// Window size (taken from the flow window on `init`) above which slow start ends.
    max_window_size: f64,
    /// Peer receive/delivery rate, in packets per second (0 = not measured yet).
    recv_rate: u32,
    /// Peer bandwidth estimate, in packets per second.
    bandwidth: u32,
    /// Current round-trip time estimate.
    rtt: Duration,
    /// Maximum segment size in bytes (1500 until `init` supplies the real value).
    mss: f64,

    /// Most recently recorded send sequence number; becomes the decrease point on rate decreases.
    curr_snd_seq_number: SeqNumber,
    /// Rate-control interval: ACK-driven updates are applied at most once per interval.
    rc_interval: Duration,
    /// Time of the last ACK-driven update.
    last_rate_increase: Instant,
    /// Whether the controller is still in slow start.
    slow_start: bool,
    /// ACK sequence number seen at the last slow-start window update.
    last_ack: SeqNumber,
    /// Whether a loss has happened since the last rate increase.
    loss: bool,
    /// Send sequence number recorded at the last rate decrease.
    last_dec_seq: SeqNumber,
    /// Sending period in effect just before the last congestion-period decrease.
    last_dec_period: Duration,
    /// NAKs counted in the current congestion period.
    nak_count: usize,
    /// Random divisor: within a congestion period the rate is decreased on every
    /// `dec_random`-th counted NAK.
    dec_random: usize,
    /// Moving average (0.875 / 0.125 weights) of NAKs per congestion period.
    avg_nak_num: usize,
    /// Loss reports seen in the current congestion period.
    dec_count: usize,

    /// ACK timer period (reported capped at `SYN_INTERVAL`).
    ack_period: Duration,
    /// Packet-count ACK interval set via `set_pkt_interval`; 0 until set.
    /// NOTE(review): presumably "send an ACK every N packets" — confirm at call site.
    ack_pkt_interval: usize,
}
33
34impl RateControl {
35    pub(crate) fn new() -> Self {
36        Self {
37            pkt_send_period: Duration::from_micros(1),
38            congestion_window_size: 16.0,
39            max_window_size: 16.0,
40            recv_rate: 0,
41            bandwidth: 0,
42            rtt: Duration::default(),
43            mss: 1500.0,
44
45            curr_snd_seq_number: SeqNumber::zero(),
46            rc_interval: SYN_INTERVAL,
47            last_rate_increase: Instant::now(),
48            slow_start: true,
49            last_ack: SeqNumber::zero(),
50            loss: false,
51            last_dec_seq: SeqNumber::zero() - 1,
52            last_dec_period: Duration::from_micros(1),
53            nak_count: 0,
54            avg_nak_num: 0,
55            dec_random: 1,
56            dec_count: 0,
57
58            ack_period: SYN_INTERVAL,
59            ack_pkt_interval: 0,
60        }
61    }
62
63    pub(crate) fn init(&mut self, mss: u32, flow: &UdtFlow, seq_number: SeqNumber) {
64        self.last_rate_increase = Instant::now();
65        self.mss = mss as f64;
66        self.max_window_size = flow.flow_window_size as f64;
67
68        self.slow_start = true;
69        self.loss = false;
70        self.curr_snd_seq_number = seq_number;
71        self.last_ack = seq_number;
72        self.last_dec_seq = seq_number - 1;
73
74        self.recv_rate = flow.peer_delivery_rate;
75        self.bandwidth = flow.peer_bandwidth;
76        self.rtt = flow.rtt;
77    }
78
79    pub fn get_pkt_send_period(&self) -> Duration {
80        self.pkt_send_period
81    }
82
83    pub fn get_congestion_window_size(&self) -> u32 {
84        self.congestion_window_size as u32
85    }
86
87    pub fn get_ack_pkt_interval(&self) -> usize {
88        self.ack_pkt_interval
89    }
90
91    pub fn get_ack_period(&self) -> Duration {
92        std::cmp::min(SYN_INTERVAL, self.ack_period)
93    }
94
95    pub fn set_rtt(&mut self, rtt: Duration) {
96        self.rtt = rtt;
97    }
98
99    pub fn set_rcv_rate(&mut self, pkt_per_sec: u32) {
100        self.recv_rate = pkt_per_sec;
101    }
102
103    pub fn set_bandwidth(&mut self, pkt_per_sec: u32) {
104        self.bandwidth = pkt_per_sec;
105    }
106
107    pub fn set_pkt_interval(&mut self, nb_pkts: usize) {
108        self.ack_pkt_interval = nb_pkts;
109    }
110
111    pub fn on_ack(&mut self, ack: SeqNumber) {
112        const MIN_INC: f64 = 0.01;
113
114        let now = Instant::now();
115        if (now - self.last_rate_increase) < self.rc_interval {
116            return;
117        }
118        self.last_rate_increase = now;
119
120        if self.slow_start {
121            self.congestion_window_size += (ack - self.last_ack) as f64;
122            self.last_ack = ack;
123            if self.congestion_window_size > self.max_window_size {
124                self.slow_start = false;
125                if self.recv_rate > 0 {
126                    self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
127                } else {
128                    self.pkt_send_period =
129                        (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
130                }
131            }
132        } else {
133            self.congestion_window_size =
134                self.recv_rate as f64 * (self.rtt + self.rc_interval).as_secs_f64() + 16.0
135        }
136
137        if self.slow_start {
138            return;
139        }
140
141        if self.loss {
142            self.loss = false;
143            return;
144        }
145
146        let mut b = self.bandwidth as f64 - 1.0 / self.pkt_send_period.as_secs_f64();
147        if (self.pkt_send_period > self.last_dec_period) && (self.bandwidth as f64 / 9.0 < b) {
148            b = self.bandwidth as f64 / 9.0;
149        }
150        let increase = if b <= 0.0 {
151            MIN_INC
152        } else {
153            let inc = 10.0_f64.powf((b * self.mss as f64 * 8.0).log10().ceil()) * 1.5e-6 / self.mss;
154            if inc < MIN_INC {
155                MIN_INC
156            } else {
157                inc
158            }
159        };
160        self.pkt_send_period = Duration::from_secs_f64(
161            (self.pkt_send_period.as_secs_f64() * self.rc_interval.as_secs_f64())
162                / (self.pkt_send_period.mul_f64(increase) + self.rc_interval).as_secs_f64(),
163        );
164    }
165
166    pub fn on_loss(&mut self, loss_seq: SeqNumber) {
167        if self.slow_start {
168            self.slow_start = false;
169            if self.recv_rate > 0 {
170                self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
171                return;
172            }
173            self.pkt_send_period =
174                (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
175        }
176
177        self.loss = true;
178        if (loss_seq - self.last_dec_seq) > 0 {
179            self.last_dec_period = self.pkt_send_period;
180            self.pkt_send_period = self.pkt_send_period.mul_f64(1.125);
181            self.avg_nak_num =
182                (self.avg_nak_num as f64 * 0.875 + self.nak_count as f64 * 0.125).ceil() as usize;
183            self.nak_count = 1;
184            self.dec_count = 1;
185            self.last_dec_seq = self.curr_snd_seq_number;
186
187            self.dec_random = if self.avg_nak_num == 0 {
188                1
189            } else {
190                rand::thread_rng().gen_range(1..=self.avg_nak_num)
191            };
192        } else {
193            self.dec_count += 1;
194            if self.dec_count <= 5 {
195                self.nak_count += 1;
196                if self.nak_count % self.dec_random == 0 {
197                    self.pkt_send_period = self.pkt_send_period.mul_f64(1.125);
198                    self.last_dec_seq = self.curr_snd_seq_number;
199                }
200            }
201        }
202    }
203
204    pub fn set_curr_snd_seq_number(&mut self, seq: SeqNumber) {
205        self.curr_snd_seq_number = seq;
206    }
207
208    pub fn on_timeout(&mut self) {
209        if self.slow_start {
210            self.slow_start = false;
211            if self.recv_rate > 0 {
212                self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
213            } else {
214                self.pkt_send_period =
215                    (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
216            }
217        }
218    }
219}