s2n_quic/provider/
endpoint_limits.rs1pub use s2n_quic_core::endpoint::{
7 limits::{ConnectionAttempt, Outcome},
8 Limiter,
9};
10use s2n_quic_core::{event::Timestamp, path::THROTTLED_PORTS_LEN};
11
12pub trait Provider: 'static {
13 type Limits: 'static + Limiter;
14 type Error: core::fmt::Display + Send + Sync;
15
16 fn start(self) -> Result<Self::Limits, Self::Error>;
18}
19
20use core::time::Duration;
21pub use default::Limits as Default;
22
23impl_provider_utils!();
24
25impl<T: 'static + Limiter> Provider for T {
26 type Limits = T;
27 type Error = core::convert::Infallible;
28
29 fn start(self) -> Result<Self::Limits, Self::Error> {
30 Ok(self)
31 }
32}
33
/// Number of connection attempts a throttled port may make before the default
/// limiter starts dropping them within one `THROTTLE_FREQUENCY` window.
const THROTTLED_PORT_LIMIT: usize = 10;

/// Length of the window after which a throttled port's attempt counter may be reset.
const THROTTLE_FREQUENCY: Duration = Duration::from_secs(1);
36
37#[derive(Default, Debug, Clone, Copy)]
38struct BasicRateLimiter {
39 last_throttle_reset: Option<Timestamp>,
40 count: usize,
41}
42
43impl BasicRateLimiter {
44 fn should_throttle(
52 &mut self,
53 limit: usize,
54 throttle_frequency: Duration,
55 connection_attempt: &ConnectionAttempt,
56 ) -> bool {
57 self.count += 1;
58 let timestamp = connection_attempt.timestamp;
59
60 if self.count > limit {
61 match self.last_throttle_reset {
62 Some(last_throttle_reset)
65 if timestamp.saturating_duration_since(last_throttle_reset)
66 < throttle_frequency =>
67 {
68 return true;
69 }
70 _ => {
74 self.count = 0;
75 self.last_throttle_reset = Some(timestamp);
76 return false;
77 }
78 };
79 }
80
81 if self.last_throttle_reset.is_none() {
83 self.last_throttle_reset = Some(timestamp);
84 }
85
86 false
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::{BasicRateLimiter, THROTTLED_PORT_LIMIT, THROTTLE_FREQUENCY};
    use core::time::Duration;
    use s2n_quic_core::{
        endpoint::limits::ConnectionAttempt,
        event::IntoEvent,
        inet::SocketAddress,
        time::{testing::Clock as MockClock, Clock},
    };

    /// With a window that never elapses, the first `THROTTLED_PORT_LIMIT` attempts
    /// pass and every later attempt is throttled.
    #[test]
    fn first_throttle_reset() {
        let peer = SocketAddress::default();
        let clock = MockClock::default();
        let attempt = ConnectionAttempt::new(0, 0, &peer, clock.get_time().into_event());

        let mut limiter = BasicRateLimiter::default();
        let never_elapsing_window = Duration::MAX;

        for attempt_number in 0..(THROTTLED_PORT_LIMIT * 3) {
            let throttled =
                limiter.should_throttle(THROTTLED_PORT_LIMIT, never_elapsing_window, &attempt);
            let over_limit = attempt_number >= THROTTLED_PORT_LIMIT;
            assert_eq!(throttled, over_limit);
        }
    }

    /// When the clock advances past the window every `THROTTLED_PORT_LIMIT`
    /// attempts, no attempt is ever throttled.
    #[test]
    fn throttle_timer_reset() {
        let peer = SocketAddress::default();
        let mut clock = MockClock::default();

        let mut limiter = BasicRateLimiter::default();
        let window = Duration::from_millis(10);
        let pause_longer_than_window = Duration::from_millis(500);

        for attempt_number in 0..(THROTTLED_PORT_LIMIT * 3) {
            // The attempt captures the current time before the clock is advanced.
            let attempt = ConnectionAttempt::new(0, 0, &peer, clock.get_time().into_event());
            if attempt_number % THROTTLED_PORT_LIMIT == 0 {
                clock.inc_by(pause_longer_than_window);
            }

            let throttled = limiter.should_throttle(THROTTLED_PORT_LIMIT, window, &attempt);
            assert!(!throttled);
        }
    }

    /// Pins the throttle constants so that any change to them has to be deliberate.
    #[test]
    fn throttle_constants_changed() {
        assert_eq!(THROTTLED_PORT_LIMIT, 10);
        assert_eq!(THROTTLE_FREQUENCY, Duration::from_secs(1));
    }
}
155
156pub mod default {
157 use super::*;
160 use core::convert::Infallible;
161
/// Builder for the default endpoint [`Limits`].
///
/// `Debug` is derived so this public type can be logged and inspected like the
/// `Limits` value it produces.
#[derive(Default, Debug)]
pub struct Builder {
    /// Optional cap on in-flight handshakes; `None` leaves the limit disabled.
    max_inflight_handshake_limit: Option<usize>,
}
184
185 impl Builder {
186 pub fn with_inflight_handshake_limit(mut self, limit: usize) -> Result<Self, Infallible> {
188 self.max_inflight_handshake_limit = Some(limit);
189 Ok(self)
190 }
191
192 pub fn build(self) -> Result<Limits, Infallible> {
194 Ok(Limits {
195 max_inflight_handshake_limit: self.max_inflight_handshake_limit,
196 rate_limiter: [BasicRateLimiter::default(); THROTTLED_PORTS_LEN],
197 })
198 }
199 }
200
201 #[derive(Clone, Copy, Debug)]
202 pub struct Limits {
203 max_inflight_handshake_limit: Option<usize>,
205 rate_limiter: [BasicRateLimiter; THROTTLED_PORTS_LEN],
206 }
207
208 impl Limits {
209 pub fn builder() -> Builder {
210 Builder::default()
211 }
212 }
213
214 impl super::Limiter for Limits {
216 fn on_connection_attempt(&mut self, info: &ConnectionAttempt) -> Outcome {
217 let remote_port = info.remote_address.port();
218 if s2n_quic_core::path::remote_port_blocked(remote_port) {
219 return Outcome::drop();
220 }
221
222 if let Some(port_index) = s2n_quic_core::path::remote_port_throttled_index(remote_port)
223 {
224 let rate_limiter = &mut self.rate_limiter[port_index];
225 if rate_limiter.should_throttle(THROTTLED_PORT_LIMIT, THROTTLE_FREQUENCY, info) {
226 return Outcome::drop();
227 }
228 }
229
230 if let Some(limit) = self.max_inflight_handshake_limit {
231 if info.inflight_handshakes >= limit {
232 return Outcome::retry();
233 }
234 }
235
236 Outcome::allow()
237 }
238 }
239
240 impl std::default::Default for Limits {
242 fn default() -> Self {
243 Self {
244 max_inflight_handshake_limit: None,
245 rate_limiter: [BasicRateLimiter::default(); THROTTLED_PORTS_LEN],
246 }
247 }
248 }
249
/// The limit passed to the builder ends up in the built `Limits`.
#[test]
fn builder_test() {
    let builder = Limits::builder()
        .with_inflight_handshake_limit(100)
        .unwrap();
    let limits = builder.build().unwrap();

    assert_eq!(limits.max_inflight_handshake_limit, Some(100));
}
259
/// Every remote port, including `u16::MAX`, is dropped exactly when
/// `remote_port_blocked` reports it as blocked and is allowed otherwise.
///
/// Each port is attempted only once, so a throttled port stays below
/// `THROTTLED_PORT_LIMIT` and is still expected to be allowed.
#[test]
fn blocked_port_connection_attempt() {
    use s2n_quic_core::{
        event::IntoEvent,
        inet::SocketAddress,
        time::{testing::Clock as MockClock, Clock},
    };

    let mut remote_address = SocketAddress::default();
    let mut limits = Limits::builder().build().unwrap();
    let mock_clock = MockClock::default();

    // The range must be inclusive: `0..u16::MAX` would silently skip port 65535.
    for port in 0..=u16::MAX {
        let blocked_expected = s2n_quic_core::path::remote_port_blocked(port);

        remote_address.set_port(port);
        let info =
            ConnectionAttempt::new(0, 0, &remote_address, mock_clock.get_time().into_event());
        let outcome = limits.on_connection_attempt(&info);

        if blocked_expected {
            assert_eq!(Outcome::drop(), outcome);
        } else {
            assert_eq!(Outcome::allow(), outcome);
        }
    }
}
287}