tokio_udt/
rate_control.rs1use crate::flow::UdtFlow;
2use crate::seq_number::SeqNumber;
3use crate::socket::SYN_INTERVAL;
4use rand::Rng;
5use tokio::time::{Duration, Instant};
6
7#[derive(Debug)]
8pub struct RateControl {
9 pkt_send_period: Duration,
10 congestion_window_size: f64,
11 max_window_size: f64,
12 recv_rate: u32,
13 bandwidth: u32,
14 rtt: Duration,
15 mss: f64,
16
17 curr_snd_seq_number: SeqNumber,
18 rc_interval: Duration,
19 last_rate_increase: Instant,
20 slow_start: bool,
21 last_ack: SeqNumber,
22 loss: bool, last_dec_seq: SeqNumber,
24 last_dec_period: Duration,
25 nak_count: usize,
26 dec_random: usize,
27 avg_nak_num: usize,
28 dec_count: usize,
29
30 ack_period: Duration,
31 ack_pkt_interval: usize,
32}
33
34impl RateControl {
35 pub(crate) fn new() -> Self {
36 Self {
37 pkt_send_period: Duration::from_micros(1),
38 congestion_window_size: 16.0,
39 max_window_size: 16.0,
40 recv_rate: 0,
41 bandwidth: 0,
42 rtt: Duration::default(),
43 mss: 1500.0,
44
45 curr_snd_seq_number: SeqNumber::zero(),
46 rc_interval: SYN_INTERVAL,
47 last_rate_increase: Instant::now(),
48 slow_start: true,
49 last_ack: SeqNumber::zero(),
50 loss: false,
51 last_dec_seq: SeqNumber::zero() - 1,
52 last_dec_period: Duration::from_micros(1),
53 nak_count: 0,
54 avg_nak_num: 0,
55 dec_random: 1,
56 dec_count: 0,
57
58 ack_period: SYN_INTERVAL,
59 ack_pkt_interval: 0,
60 }
61 }
62
63 pub(crate) fn init(&mut self, mss: u32, flow: &UdtFlow, seq_number: SeqNumber) {
64 self.last_rate_increase = Instant::now();
65 self.mss = mss as f64;
66 self.max_window_size = flow.flow_window_size as f64;
67
68 self.slow_start = true;
69 self.loss = false;
70 self.curr_snd_seq_number = seq_number;
71 self.last_ack = seq_number;
72 self.last_dec_seq = seq_number - 1;
73
74 self.recv_rate = flow.peer_delivery_rate;
75 self.bandwidth = flow.peer_bandwidth;
76 self.rtt = flow.rtt;
77 }
78
79 pub fn get_pkt_send_period(&self) -> Duration {
80 self.pkt_send_period
81 }
82
83 pub fn get_congestion_window_size(&self) -> u32 {
84 self.congestion_window_size as u32
85 }
86
87 pub fn get_ack_pkt_interval(&self) -> usize {
88 self.ack_pkt_interval
89 }
90
91 pub fn get_ack_period(&self) -> Duration {
92 std::cmp::min(SYN_INTERVAL, self.ack_period)
93 }
94
95 pub fn set_rtt(&mut self, rtt: Duration) {
96 self.rtt = rtt;
97 }
98
99 pub fn set_rcv_rate(&mut self, pkt_per_sec: u32) {
100 self.recv_rate = pkt_per_sec;
101 }
102
103 pub fn set_bandwidth(&mut self, pkt_per_sec: u32) {
104 self.bandwidth = pkt_per_sec;
105 }
106
107 pub fn set_pkt_interval(&mut self, nb_pkts: usize) {
108 self.ack_pkt_interval = nb_pkts;
109 }
110
111 pub fn on_ack(&mut self, ack: SeqNumber) {
112 const MIN_INC: f64 = 0.01;
113
114 let now = Instant::now();
115 if (now - self.last_rate_increase) < self.rc_interval {
116 return;
117 }
118 self.last_rate_increase = now;
119
120 if self.slow_start {
121 self.congestion_window_size += (ack - self.last_ack) as f64;
122 self.last_ack = ack;
123 if self.congestion_window_size > self.max_window_size {
124 self.slow_start = false;
125 if self.recv_rate > 0 {
126 self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
127 } else {
128 self.pkt_send_period =
129 (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
130 }
131 }
132 } else {
133 self.congestion_window_size =
134 self.recv_rate as f64 * (self.rtt + self.rc_interval).as_secs_f64() + 16.0
135 }
136
137 if self.slow_start {
138 return;
139 }
140
141 if self.loss {
142 self.loss = false;
143 return;
144 }
145
146 let mut b = self.bandwidth as f64 - 1.0 / self.pkt_send_period.as_secs_f64();
147 if (self.pkt_send_period > self.last_dec_period) && (self.bandwidth as f64 / 9.0 < b) {
148 b = self.bandwidth as f64 / 9.0;
149 }
150 let increase = if b <= 0.0 {
151 MIN_INC
152 } else {
153 let inc = 10.0_f64.powf((b * self.mss as f64 * 8.0).log10().ceil()) * 1.5e-6 / self.mss;
154 if inc < MIN_INC {
155 MIN_INC
156 } else {
157 inc
158 }
159 };
160 self.pkt_send_period = Duration::from_secs_f64(
161 (self.pkt_send_period.as_secs_f64() * self.rc_interval.as_secs_f64())
162 / (self.pkt_send_period.mul_f64(increase) + self.rc_interval).as_secs_f64(),
163 );
164 }
165
166 pub fn on_loss(&mut self, loss_seq: SeqNumber) {
167 if self.slow_start {
168 self.slow_start = false;
169 if self.recv_rate > 0 {
170 self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
171 return;
172 }
173 self.pkt_send_period =
174 (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
175 }
176
177 self.loss = true;
178 if (loss_seq - self.last_dec_seq) > 0 {
179 self.last_dec_period = self.pkt_send_period;
180 self.pkt_send_period = self.pkt_send_period.mul_f64(1.125);
181 self.avg_nak_num =
182 (self.avg_nak_num as f64 * 0.875 + self.nak_count as f64 * 0.125).ceil() as usize;
183 self.nak_count = 1;
184 self.dec_count = 1;
185 self.last_dec_seq = self.curr_snd_seq_number;
186
187 self.dec_random = if self.avg_nak_num == 0 {
188 1
189 } else {
190 rand::thread_rng().gen_range(1..=self.avg_nak_num)
191 };
192 } else {
193 self.dec_count += 1;
194 if self.dec_count <= 5 {
195 self.nak_count += 1;
196 if self.nak_count % self.dec_random == 0 {
197 self.pkt_send_period = self.pkt_send_period.mul_f64(1.125);
198 self.last_dec_seq = self.curr_snd_seq_number;
199 }
200 }
201 }
202 }
203
204 pub fn set_curr_snd_seq_number(&mut self, seq: SeqNumber) {
205 self.curr_snd_seq_number = seq;
206 }
207
208 pub fn on_timeout(&mut self) {
209 if self.slow_start {
210 self.slow_start = false;
211 if self.recv_rate > 0 {
212 self.pkt_send_period = Duration::from_secs(1) / self.recv_rate;
213 } else {
214 self.pkt_send_period =
215 (self.rtt + self.rc_interval).div_f64(self.congestion_window_size);
216 }
217 }
218 }
219}