1use crate::{
5 event,
6 event::{
7 api::SocketAddress,
8 builder::{BbrState, SlowStartExitCause},
9 IntoEvent,
10 },
11 inet, path,
12 path::Config,
13 random,
14 recovery::{
15 bandwidth::{Bandwidth, RateSample},
16 RttEstimator,
17 },
18 time::Timestamp,
19};
20use core::fmt::Debug;
21use num_rational::Ratio;
22use num_traits::ToPrimitive;
23
/// Factory for congestion controllers.
///
/// An implementor picks the concrete [`CongestionController`] type and builds
/// instances of it from a [`PathInfo`].
pub trait Endpoint: 'static + Debug + Send {
    /// The congestion controller type produced by this endpoint.
    type CongestionController: CongestionController;

    /// Builds a congestion controller from the supplied `path_info`.
    // NOTE(review): presumably invoked once for each new path — confirm against callers.
    fn new_congestion_controller(&mut self, path_info: PathInfo) -> Self::CongestionController;
}
29
/// Description of a path, handed to [`Endpoint::new_congestion_controller`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; [`PathInfo::new`] is the constructor.
#[derive(Debug)]
#[non_exhaustive]
pub struct PathInfo<'a> {
    /// Address of the remote peer, in its event representation.
    pub remote_address: SocketAddress<'a>,
    /// Application protocol bytes, if any; [`PathInfo::new`] sets this to `None`.
    // NOTE(review): presumably the negotiated ALPN value — confirm.
    pub application_protocol: Option<&'a [u8]>,
    /// Maximum datagram size for the path.
    // NOTE(review): presumably expressed in bytes — confirm.
    pub max_datagram_size: u16,
}
37
38impl<'a> PathInfo<'a> {
39 #[allow(deprecated)]
40 pub fn new(mtu_config: &Config, remote_address: &'a inet::SocketAddress) -> Self {
41 Self {
42 remote_address: remote_address.into_event(),
43 application_protocol: None,
44 max_datagram_size: mtu_config.initial_mtu().max_datagram_size(remote_address),
45 }
46 }
47}
48
/// Receiver for the events a congestion controller can emit.
///
/// The `on_*` callbacks of [`CongestionController`] take a `&mut` implementor
/// of this trait as their `publisher` argument.
pub trait Publisher {
    /// Reports that slow start was exited for the given `cause`, together
    /// with a congestion window value.
    // NOTE(review): presumably the window at the moment of exit, in bytes — confirm.
    fn on_slow_start_exited(&mut self, cause: SlowStartExitCause, congestion_window: u32);
    /// Reports a delivery rate sample.
    fn on_delivery_rate_sampled(&mut self, rate_sample: RateSample);
    /// Reports an updated pacing rate, along with the burst size and the
    /// pacing gain that accompany it.
    fn on_pacing_rate_updated(
        &mut self,
        pacing_rate: Bandwidth,
        burst_size: u32,
        pacing_gain: Ratio<u64>,
    );
    /// Reports a change of BBR state.
    fn on_bbr_state_changed(&mut self, state: BbrState);
}
64
/// [`Publisher`] implementation that forwards every event to a
/// connection-level publisher, tagging each one with a fixed path id.
pub struct PathPublisher<'a, Pub: event::ConnectionPublisher> {
    /// Connection publisher that receives the built events.
    publisher: &'a mut Pub,
    /// Path id attached to every event this publisher emits.
    path_id: path::Id,
}
71
72impl<'a, Pub: event::ConnectionPublisher> PathPublisher<'a, Pub> {
73 pub fn new(publisher: &'a mut Pub, path_id: path::Id) -> PathPublisher<'a, Pub> {
75 Self { publisher, path_id }
76 }
77}
78
79impl<Pub: event::ConnectionPublisher> Publisher for PathPublisher<'_, Pub> {
80 #[inline]
81 fn on_slow_start_exited(&mut self, cause: SlowStartExitCause, congestion_window: u32) {
82 self.publisher
83 .on_slow_start_exited(event::builder::SlowStartExited {
84 path_id: self.path_id.into_event(),
85 cause,
86 congestion_window,
87 });
88 }
89
90 #[inline]
91 fn on_delivery_rate_sampled(&mut self, rate_sample: RateSample) {
92 self.publisher
93 .on_delivery_rate_sampled(event::builder::DeliveryRateSampled {
94 path_id: self.path_id.into_event(),
95 rate_sample: rate_sample.into_event(),
96 })
97 }
98
99 #[inline]
100 fn on_pacing_rate_updated(
101 &mut self,
102 pacing_rate: Bandwidth,
103 burst_size: u32,
104 pacing_gain: Ratio<u64>,
105 ) {
106 self.publisher
107 .on_pacing_rate_updated(event::builder::PacingRateUpdated {
108 path_id: self.path_id.into_event(),
109 bytes_per_second: pacing_rate.as_bytes_per_second(),
110 burst_size,
111 pacing_gain: pacing_gain
112 .to_f32()
113 .expect("pacing gain should be representable as f32"),
114 })
115 }
116
117 #[inline]
118 fn on_bbr_state_changed(&mut self, state: BbrState) {
119 self.publisher
120 .on_bbr_state_changed(event::builder::BbrStateChanged {
121 path_id: self.path_id.into_event(),
122 state,
123 })
124 }
125}
126
/// Interface between loss recovery and a congestion control algorithm.
///
/// The trait is sealed through the `private::Sealed` supertrait, which
/// restricts which types may implement it.
pub trait CongestionController: 'static + Clone + Send + Debug + private::Sealed {
    /// Per-packet metadata returned by [`Self::on_packet_sent`] and handed
    /// back to [`Self::on_ack`] and [`Self::on_packet_lost`].
    type PacketInfo: Copy + Send + Sized + Debug;

    /// Returns the current congestion window.
    // NOTE(review): presumably measured in bytes — confirm.
    fn congestion_window(&self) -> u32;

    /// Returns the amount of data currently in flight.
    // NOTE(review): presumably measured in bytes — confirm.
    fn bytes_in_flight(&self) -> u32;

    /// Returns `true` when the controller considers itself congestion limited.
    // NOTE(review): presumably means the window has no room for another
    // packet given the bytes in flight — confirm with implementations.
    fn is_congestion_limited(&self) -> bool;

    /// Returns `true` when the controller requires a fast retransmission.
    // NOTE(review): presumably allows sending without regard to the
    // available congestion window — confirm with callers.
    fn requires_fast_retransmission(&self) -> bool;

    /// Callback for a sent packet of `sent_bytes` bytes at `time_sent`.
    ///
    /// Returns the [`Self::PacketInfo`] to associate with the packet.
    // NOTE(review): the meaning of `app_limited == None` is not visible
    // here — confirm with callers.
    fn on_packet_sent<Pub: Publisher>(
        &mut self,
        time_sent: Timestamp,
        sent_bytes: usize,
        app_limited: Option<bool>,
        rtt_estimator: &RttEstimator,
        publisher: &mut Pub,
    ) -> Self::PacketInfo;

    /// Callback for a round trip time update.
    // NOTE(review): `time_sent` is presumably the send time of the packet
    // that produced the RTT sample — confirm.
    fn on_rtt_update<Pub: Publisher>(
        &mut self,
        time_sent: Timestamp,
        now: Timestamp,
        rtt_estimator: &RttEstimator,
        publisher: &mut Pub,
    );

    /// Callback for newly acknowledged data.
    ///
    /// `newest_acked_time_sent` and `newest_acked_packet_info` describe the
    /// newest acknowledged packet.
    // NOTE(review): `bytes_acknowledged` is presumably aggregated over all
    // newly acknowledged packets — confirm.
    fn on_ack<Pub: Publisher>(
        &mut self,
        newest_acked_time_sent: Timestamp,
        bytes_acknowledged: usize,
        newest_acked_packet_info: Self::PacketInfo,
        rtt_estimator: &RttEstimator,
        random_generator: &mut dyn random::Generator,
        ack_receive_time: Timestamp,
        publisher: &mut Pub,
    );

    /// Callback for a packet of `lost_bytes` bytes that was declared lost.
    ///
    /// `packet_info` is the value that [`Self::on_packet_sent`] returned for
    /// the packet.
    // NOTE(review): `new_loss_burst` presumably marks the first packet of a
    // contiguous run of losses — confirm.
    fn on_packet_lost<Pub: Publisher>(
        &mut self,
        lost_bytes: u32,
        packet_info: Self::PacketInfo,
        persistent_congestion: bool,
        new_loss_burst: bool,
        random_generator: &mut dyn random::Generator,
        timestamp: Timestamp,
        publisher: &mut Pub,
    );

    /// Callback for an explicit congestion signal.
    // NOTE(review): `ce_count` is presumably the incremental number of
    // packets marked with the ECN CE codepoint — confirm.
    fn on_explicit_congestion<Pub: Publisher>(
        &mut self,
        ce_count: u64,
        event_time: Timestamp,
        publisher: &mut Pub,
    );

    /// Callback for an MTU update on the path.
    fn on_mtu_update<Pub: Publisher>(&mut self, max_data_size: u16, publisher: &mut Pub);

    /// Callback for a discarded packet of `bytes_sent` bytes.
    // NOTE(review): presumably called when a packet number space is
    // discarded — confirm.
    fn on_packet_discarded<Pub: Publisher>(&mut self, bytes_sent: usize, publisher: &mut Pub);

    /// Returns the earliest time at which a packet may depart, if the
    /// controller imposes one.
    // NOTE(review): `None` presumably means "send immediately" — confirm.
    fn earliest_departure_time(&self) -> Option<Timestamp>;

    /// Returns the controller's send quantum, if it has one.
    ///
    /// The default implementation returns `None`.
    // NOTE(review): presumably the maximum burst size in bytes — confirm.
    fn send_quantum(&self) -> Option<usize> {
        None
    }
}
247
/// Holds the `Sealed` supertrait of `CongestionController`.
///
/// The module is private, so `Sealed` cannot be named from outside the
/// enclosing module; this limits who can implement `CongestionController`.
mod private {
    use cfg_if::cfg_if;

    /// Marker supertrait that gates `CongestionController` implementations.
    pub trait Sealed {}

    cfg_if!(
        if #[cfg(any(test, feature = "unstable-congestion-controller", feature = "testing"))] {
            // In tests, or with one of the opt-in features enabled, every
            // `CongestionController` implementation is accepted.
            impl<T: crate::recovery::CongestionController> Sealed for T {}
        } else {
            // Otherwise only the two built-in controllers are accepted.
            impl Sealed for crate::recovery::CubicCongestionController {}
            impl Sealed for crate::recovery::bbr::BbrCongestionController {}
        }
    );
}
267
268#[cfg(any(test, feature = "testing"))]
269pub mod testing {
270 use super::*;
271 use crate::recovery::RttEstimator;
272
273 pub mod unlimited {
274 use super::*;
275
276 #[derive(Debug, Default)]
277 pub struct Endpoint {}
278
279 impl super::Endpoint for Endpoint {
280 type CongestionController = CongestionController;
281
282 fn new_congestion_controller(
283 &mut self,
284 _path_info: super::PathInfo,
285 ) -> Self::CongestionController {
286 CongestionController::default()
287 }
288 }
289
        /// Stateless congestion controller that never limits sending: it
        /// reports a `u32::MAX` window and zero bytes in flight.
        #[derive(Clone, Copy, Debug, Default, PartialEq)]
        pub struct CongestionController {}
292
        /// Empty per-packet metadata used by the unlimited controller.
        #[derive(Clone, Copy, Debug, Default)]
        pub struct PacketInfo(());
296
297 impl super::CongestionController for CongestionController {
298 type PacketInfo = PacketInfo;
299
300 fn congestion_window(&self) -> u32 {
301 u32::MAX
302 }
303
304 fn bytes_in_flight(&self) -> u32 {
305 0
306 }
307
308 fn is_congestion_limited(&self) -> bool {
309 false
310 }
311
312 fn requires_fast_retransmission(&self) -> bool {
313 false
314 }
315
316 fn on_packet_sent<Pub: Publisher>(
317 &mut self,
318 _time_sent: Timestamp,
319 _bytes_sent: usize,
320 _app_limited: Option<bool>,
321 _rtt_estimator: &RttEstimator,
322 _publisher: &mut Pub,
323 ) -> PacketInfo {
324 PacketInfo(())
325 }
326
327 fn on_rtt_update<Pub: Publisher>(
328 &mut self,
329 _time_sent: Timestamp,
330 _now: Timestamp,
331 _rtt_estimator: &RttEstimator,
332 _publisher: &mut Pub,
333 ) {
334 }
335
336 fn on_ack<Pub: Publisher>(
337 &mut self,
338 _newest_acked_time_sent: Timestamp,
339 _sent_bytes: usize,
340 _newest_acked_packet_info: Self::PacketInfo,
341 _rtt_estimator: &RttEstimator,
342 _random_generator: &mut dyn random::Generator,
343 _ack_receive_time: Timestamp,
344 _publisher: &mut Pub,
345 ) {
346 }
347
348 fn on_packet_lost<Pub: Publisher>(
349 &mut self,
350 _lost_bytes: u32,
351 _packet_info: Self::PacketInfo,
352 _persistent_congestion: bool,
353 _new_loss_burst: bool,
354 _random_generator: &mut dyn random::Generator,
355 _timestamp: Timestamp,
356 _publisher: &mut Pub,
357 ) {
358 }
359
360 fn on_explicit_congestion<Pub: Publisher>(
361 &mut self,
362 _ce_count: u64,
363 _event_time: Timestamp,
364 _publisher: &mut Pub,
365 ) {
366 }
367
368 fn on_mtu_update<Pub: Publisher>(&mut self, _max_data_size: u16, _publisher: &mut Pub) {
369 }
370
371 fn on_packet_discarded<Pub: Publisher>(
372 &mut self,
373 _bytes_sent: usize,
374 _publisher: &mut Pub,
375 ) {
376 }
377
378 fn earliest_departure_time(&self) -> Option<Timestamp> {
379 None
380 }
381 }
382 }
383
384 pub mod mock {
385 use super::*;
386 use crate::path::RemoteAddress;
387
388 #[derive(Debug, Default)]
389 pub struct Endpoint {}
390
391 impl super::Endpoint for Endpoint {
392 type CongestionController = CongestionController;
393
394 fn new_congestion_controller(
395 &mut self,
396 path_info: super::PathInfo,
397 ) -> Self::CongestionController {
398 CongestionController::new(path_info.remote_address.into())
399 }
400 }
401
        /// Per-packet metadata produced by the mock controller.
        ///
        /// It records the remote address of the controller that sent the
        /// packet, so `on_ack` and `on_packet_lost` can assert that the packet
        /// is handed back to the same controller.
        #[derive(Clone, Copy, Debug, Default)]
        pub struct PacketInfo {
            /// Remote address of the controller that created this value.
            remote_address: RemoteAddress,
        }
406
        /// Mock congestion controller that records how it was called so that
        /// tests can inspect the public fields afterwards.
        #[derive(Clone, Copy, Debug, PartialEq)]
        pub struct CongestionController {
            /// Bytes added by `on_packet_sent` and removed (saturating) by
            /// `on_packet_lost` and `on_packet_discarded`.
            pub bytes_in_flight: u32,
            /// Sum of the `lost_bytes` passed to `on_packet_lost`.
            pub lost_bytes: u32,
            /// `persistent_congestion` argument of the latest `on_packet_lost`
            /// call; `None` until the first call.
            pub persistent_congestion: Option<bool>,
            /// Number of `on_packet_lost` calls.
            pub on_packets_lost: u32,
            /// Number of `on_rtt_update` calls.
            pub on_rtt_update: u32,
            /// Number of `on_ack` calls.
            pub on_packet_ack: u32,
            /// Number of `on_mtu_update` calls.
            pub on_mtu_update: u32,
            /// Value reported by `congestion_window()`; defaults to `1500 * 10`.
            pub congestion_window: u32,
            /// Number of `on_explicit_congestion` calls.
            pub congestion_events: u32,
            /// Set by `on_packet_lost` and cleared by `on_packet_sent`.
            pub requires_fast_retransmission: bool,
            /// Number of `on_packet_lost` calls made with `new_loss_burst` set.
            pub loss_bursts: u32,
            /// `app_limited` argument of the latest `on_packet_sent` call.
            pub app_limited: Option<bool>,
            /// Starts as `true`; cleared by `on_explicit_congestion`.
            pub slow_start: bool,
            /// Remote address this controller is bound to.
            pub remote_address: RemoteAddress,
        }
424
425 impl Default for CongestionController {
426 fn default() -> Self {
427 Self {
428 bytes_in_flight: 0,
429 lost_bytes: 0,
430 persistent_congestion: None,
431 on_packets_lost: 0,
432 on_rtt_update: 0,
433 on_packet_ack: 0,
434 on_mtu_update: 0,
435 congestion_window: 1500 * 10,
436 congestion_events: 0,
437 requires_fast_retransmission: false,
438 loss_bursts: 0,
439 app_limited: None,
440 slow_start: true,
441 remote_address: RemoteAddress::default(),
442 }
443 }
444 }
445
446 impl CongestionController {
447 pub fn new(remote_address: RemoteAddress) -> Self {
448 Self {
449 remote_address,
450 ..Default::default()
451 }
452 }
453 }
454
455 impl super::CongestionController for CongestionController {
456 type PacketInfo = PacketInfo;
457
458 fn congestion_window(&self) -> u32 {
459 self.congestion_window
460 }
461
462 fn bytes_in_flight(&self) -> u32 {
463 self.bytes_in_flight
464 }
465
466 fn is_congestion_limited(&self) -> bool {
467 self.requires_fast_retransmission || self.bytes_in_flight >= self.congestion_window
468 }
469
470 fn requires_fast_retransmission(&self) -> bool {
471 self.requires_fast_retransmission
472 }
473
474 fn on_packet_sent<Pub: Publisher>(
475 &mut self,
476 _time_sent: Timestamp,
477 bytes_sent: usize,
478 app_limited: Option<bool>,
479 _rtt_estimator: &RttEstimator,
480 _publisher: &mut Pub,
481 ) -> PacketInfo {
482 self.bytes_in_flight += bytes_sent as u32;
483 self.requires_fast_retransmission = false;
484 self.app_limited = app_limited;
485 PacketInfo {
486 remote_address: self.remote_address,
487 }
488 }
489
490 fn on_rtt_update<Pub: Publisher>(
491 &mut self,
492 _time_sent: Timestamp,
493 _now: Timestamp,
494 _rtt_estimator: &RttEstimator,
495 _publisher: &mut Pub,
496 ) {
497 self.on_rtt_update += 1
498 }
499
500 fn on_ack<Pub: Publisher>(
501 &mut self,
502 _newest_acked_time_sent: Timestamp,
503 _sent_bytes: usize,
504 newest_acked_packet_info: Self::PacketInfo,
505 _rtt_estimator: &RttEstimator,
506 _random_generator: &mut dyn random::Generator,
507 _ack_receive_time: Timestamp,
508 _publisher: &mut Pub,
509 ) {
510 assert_eq!(self.remote_address, newest_acked_packet_info.remote_address);
511
512 self.on_packet_ack += 1;
513 }
514
515 fn on_packet_lost<Pub: Publisher>(
516 &mut self,
517 lost_bytes: u32,
518 packet_info: Self::PacketInfo,
519 persistent_congestion: bool,
520 new_loss_burst: bool,
521 _random_generator: &mut dyn random::Generator,
522 _timestamp: Timestamp,
523 _publisher: &mut Pub,
524 ) {
525 assert_eq!(self.remote_address, packet_info.remote_address);
526
527 self.bytes_in_flight = self.bytes_in_flight.saturating_sub(lost_bytes);
528 self.lost_bytes += lost_bytes;
529 self.persistent_congestion = Some(persistent_congestion);
530 self.on_packets_lost += 1;
531 self.requires_fast_retransmission = true;
532
533 if new_loss_burst {
534 self.loss_bursts += 1;
535 }
536 }
537
538 fn on_explicit_congestion<Pub: Publisher>(
539 &mut self,
540 _ce_count: u64,
541 _event_time: Timestamp,
542 _publisher: &mut Pub,
543 ) {
544 self.congestion_events += 1;
545 self.slow_start = false;
546 }
547
548 fn on_mtu_update<Pub: Publisher>(&mut self, _max_data_size: u16, _publisher: &mut Pub) {
549 self.on_mtu_update += 1;
550 }
551
552 fn on_packet_discarded<Pub: Publisher>(
553 &mut self,
554 bytes_sent: usize,
555 _publisher: &mut Pub,
556 ) {
557 self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_sent as u32);
558 }
559
560 fn earliest_departure_time(&self) -> Option<Timestamp> {
561 None
562 }
563 }
564 }
565}
566
567#[cfg(test)]
568mod fuzz_target;