s2n_quic/provider/endpoint_limits.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Allows applications to limit peer's ability to open new connections
5
6pub use s2n_quic_core::endpoint::{
7    limits::{ConnectionAttempt, Outcome},
8    Limiter,
9};
10use s2n_quic_core::{event::Timestamp, path::THROTTLED_PORTS_LEN};
11
12pub trait Provider: 'static {
13    type Limits: 'static + Limiter;
14    type Error: core::fmt::Display + Send + Sync;
15
16    /// Starts the token provider
17    fn start(self) -> Result<Self::Limits, Self::Error>;
18}
19
20use core::time::Duration;
21pub use default::Limits as Default;
22
// NOTE(review): `impl_provider_utils!` is defined outside this file; presumably it generates
// helper items shared by all providers — confirm against its definition.
impl_provider_utils!();
24
25impl<T: 'static + Limiter> Provider for T {
26    type Limits = T;
27    type Error = core::convert::Infallible;
28
29    fn start(self) -> Result<Self::Limits, Self::Error> {
30        Ok(self)
31    }
32}
33
/// Connection-attempt budget for each throttled remote port within one throttle window
/// (enforced by `BasicRateLimiter::should_throttle`); attempts it throttles are dropped.
const THROTTLED_PORT_LIMIT: usize = 10;
/// Length of the window over which `THROTTLED_PORT_LIMIT` is enforced.
const THROTTLE_FREQUENCY: Duration = Duration::from_millis(1_000);
36
37#[derive(Default, Debug, Clone, Copy)]
38struct BasicRateLimiter {
39    last_throttle_reset: Option<Timestamp>,
40    count: usize,
41}
42
43impl BasicRateLimiter {
44    /// Throttles a connection based on a limit.
45    /// Returns True if the `should_throttle` invoke count is greater than `limit` and the
46    /// connection has not been throttled in the last `throttle_frequency` duration.
47    ///
48    /// If the throttle timer expires the count is reset and `should_throttle` returns False.
49    ///
50    /// Returns False if the `should_throttle` invoke count is less than `limit`.
51    fn should_throttle(
52        &mut self,
53        limit: usize,
54        throttle_frequency: Duration,
55        connection_attempt: &ConnectionAttempt,
56    ) -> bool {
57        self.count += 1;
58        let timestamp = connection_attempt.timestamp;
59
60        if self.count > limit {
61            match self.last_throttle_reset {
62                // If the throttle timer is still within the throttle_frequency
63                // then throttle the connection.
64                Some(last_throttle_reset)
65                    if timestamp.saturating_duration_since(last_throttle_reset)
66                        < throttle_frequency =>
67                {
68                    return true;
69                }
70                // If the throttle timer is greater than the throttle_frequency
71                // then reset the throttle count and the throttle timer.
72                // Let the connection through.
73                _ => {
74                    self.count = 0;
75                    self.last_throttle_reset = Some(timestamp);
76                    return false;
77                }
78            };
79        }
80
81        // If this is the first time calling instantiate the throttle timer.
82        if self.last_throttle_reset.is_none() {
83            self.last_throttle_reset = Some(timestamp);
84        }
85
86        false
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::{BasicRateLimiter, THROTTLED_PORT_LIMIT, THROTTLE_FREQUENCY};
    use core::time::Duration;
    use s2n_quic_core::{
        endpoint::limits::ConnectionAttempt,
        event::IntoEvent,
        inet::SocketAddress,
        time::{testing::Clock as MockClock, Clock},
    };

    /// The first call opens the throttle window, so `THROTTLED_PORT_LIMIT` attempts are
    /// admitted and every later attempt is throttled.
    #[test]
    fn first_throttle_reset() {
        let remote_address = SocketAddress::default();
        let mock_clock = MockClock::default();
        let info =
            ConnectionAttempt::new(0, 0, &remote_address, mock_clock.get_time().into_event());

        let mut rate_limiter = BasicRateLimiter::default();
        // The mock clock never advances and the window length is `Duration::MAX`, so the
        // window opened by the first attempt cannot expire during this test.
        let very_long_freq = Duration::MAX;

        for request in 0..(THROTTLED_PORT_LIMIT * 3) {
            let throttled =
                rate_limiter.should_throttle(THROTTLED_PORT_LIMIT, very_long_freq, &info);
            assert_eq!(
                throttled,
                request >= THROTTLED_PORT_LIMIT,
                "unexpected throttle decision for request {}",
                request
            );
        }
    }

    /// Jumping the mock clock well past the throttle window before each group of
    /// `THROTTLED_PORT_LIMIT` attempts means no attempt is ever throttled.
    #[test]
    fn throttle_timer_reset() {
        let remote_address = SocketAddress::default();
        let mut mock_clock = MockClock::default();

        let mut rate_limiter = BasicRateLimiter::default();
        let short_freq = Duration::from_millis(10);
        let jump_past_short_freq = Duration::from_millis(500);

        for request in 0..(THROTTLED_PORT_LIMIT * 3) {
            // Advance the clock *before* building the attempt so that each group of
            // `THROTTLED_PORT_LIMIT` attempts is stamped after the previous window expired.
            if request % THROTTLED_PORT_LIMIT == 0 {
                mock_clock.inc_by(jump_past_short_freq);
            }
            let info =
                ConnectionAttempt::new(0, 0, &remote_address, mock_clock.get_time().into_event());
            assert!(!rate_limiter.should_throttle(THROTTLED_PORT_LIMIT, short_freq, &info));
        }
    }

    #[test]
    fn throttle_constants_changed() {
        // If these constants change, revisit the test cases above to make sure they still
        // exercise all of the intended conditions.
        //
        // For example, if THROTTLE_FREQUENCY grows to a very large period, do the tests
        // above still make sense?
        assert_eq!(THROTTLED_PORT_LIMIT, 10);
        assert_eq!(THROTTLE_FREQUENCY, Duration::from_secs(1));
    }
}
155
156pub mod default {
157    //! Default provider for the endpoint limits.
158
159    use super::*;
160    use core::convert::Infallible;
161
162    /// Allows the endpoint limits to be built with specific values
163    ///
164    /// # Examples
165    ///
166    /// Set the maximum inflight handshakes for this endpoint.
167    ///
168    /// ```rust
169    /// use s2n_quic::provider::endpoint_limits;
170    /// # use std::error::Error;
171    /// # #[tokio::main]
172    /// # async fn main() -> Result<(), Box<dyn Error>> {
173    /// let limits = endpoint_limits::Default::builder()
174    ///     .with_inflight_handshake_limit(100)?
175    ///     .build();
176    ///
177    ///     Ok(())
178    /// # }
179    /// ```
180    #[derive(Default)]
181    pub struct Builder {
182        max_inflight_handshake_limit: Option<usize>,
183    }
184
185    impl Builder {
186        /// Sets limit on inflight handshakes
187        pub fn with_inflight_handshake_limit(mut self, limit: usize) -> Result<Self, Infallible> {
188            self.max_inflight_handshake_limit = Some(limit);
189            Ok(self)
190        }
191
192        /// Build the limits
193        pub fn build(self) -> Result<Limits, Infallible> {
194            Ok(Limits {
195                max_inflight_handshake_limit: self.max_inflight_handshake_limit,
196                rate_limiter: [BasicRateLimiter::default(); THROTTLED_PORTS_LEN],
197            })
198        }
199    }
200
201    #[derive(Clone, Copy, Debug)]
202    pub struct Limits {
203        /// Maximum number of handshakes to allow before Retry packets are queued
204        max_inflight_handshake_limit: Option<usize>,
205        rate_limiter: [BasicRateLimiter; THROTTLED_PORTS_LEN],
206    }
207
208    impl Limits {
209        pub fn builder() -> Builder {
210            Builder::default()
211        }
212    }
213
214    /// Default implementation for the Limits
215    impl super::Limiter for Limits {
216        fn on_connection_attempt(&mut self, info: &ConnectionAttempt) -> Outcome {
217            let remote_port = info.remote_address.port();
218            if s2n_quic_core::path::remote_port_blocked(remote_port) {
219                return Outcome::drop();
220            }
221
222            if let Some(port_index) = s2n_quic_core::path::remote_port_throttled_index(remote_port)
223            {
224                let rate_limiter = &mut self.rate_limiter[port_index];
225                if rate_limiter.should_throttle(THROTTLED_PORT_LIMIT, THROTTLE_FREQUENCY, info) {
226                    return Outcome::drop();
227                }
228            }
229
230            if let Some(limit) = self.max_inflight_handshake_limit {
231                if info.inflight_handshakes >= limit {
232                    return Outcome::retry();
233                }
234            }
235
236            Outcome::allow()
237        }
238    }
239
240    /// Default limit values are as non-intrusive as possible
241    impl std::default::Default for Limits {
242        fn default() -> Self {
243            Self {
244                max_inflight_handshake_limit: None,
245                rate_limiter: [BasicRateLimiter::default(); THROTTLED_PORTS_LEN],
246            }
247        }
248    }
249
250    #[test]
251    fn builder_test() {
252        let elp = Limits::builder()
253            .with_inflight_handshake_limit(100)
254            .unwrap()
255            .build()
256            .unwrap();
257        assert_eq!(elp.max_inflight_handshake_limit, Some(100));
258    }
259
260    #[test]
261    fn blocked_port_connection_attempt() {
262        use s2n_quic_core::{
263            event::IntoEvent,
264            inet::SocketAddress,
265            time::{testing::Clock as MockClock, Clock},
266        };
267
268        let mut remote_address = SocketAddress::default();
269        let mut limits = Limits::builder().build().unwrap();
270        let mock_clock = MockClock::default();
271
272        for port in 0..u16::MAX {
273            let blocked_expected = s2n_quic_core::path::remote_port_blocked(port);
274
275            remote_address.set_port(port);
276            let info =
277                ConnectionAttempt::new(0, 0, &remote_address, mock_clock.get_time().into_event());
278            let outcome = limits.on_connection_attempt(&info);
279
280            if blocked_expected {
281                assert_eq!(Outcome::drop(), outcome);
282            } else {
283                assert_eq!(Outcome::allow(), outcome);
284            }
285        }
286    }
287}