s2n_quic/provider/
connection_id.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//! Provides connection ID support for an endpoint

6pub use s2n_quic_core::connection::id::{ConnectionInfo, Format, Generator, LocalId, Validator};
7
8pub trait Provider: 'static {
9    type Format: 'static + Format;
10    type Error: core::fmt::Display + Send + Sync;
11
12    fn start(self) -> Result<Self::Format, Self::Error>;
13}
14
15pub use default::Provider as Default;
16
17impl_provider_utils!();
18
19impl<T: 'static + Format> Provider for T {
20    type Format = T;
21    type Error = core::convert::Infallible;
22
23    fn start(self) -> Result<Self::Format, Self::Error> {
24        Ok(self)
25    }
26}
27
28pub mod default {
29    use core::{
30        convert::{Infallible, TryInto},
31        time::Duration,
32    };
33    use rand::prelude::*;
34    use s2n_quic_core::connection::{
35        self,
36        id::{ConnectionInfo, Generator, Validator},
37    };
38
39    #[derive(Debug, Default)]
40    pub struct Provider(Format);
41
42    impl super::Provider for Provider {
43        type Format = Format;
44        type Error = Infallible;
45
46        fn start(self) -> Result<Self::Format, Self::Error> {
47            Ok(self.0)
48        }
49    }
50
51    /// 16 bytes should be big enough for a randomly generated Id
52    const DEFAULT_LEN: usize = 16;
53
54    /// Randomly generated connection Id format.
55    ///
56    /// By default, connection Ids of length 16 bytes are generated.
57    #[derive(Debug)]
58    pub struct Format {
59        len: usize,
60        lifetime: Option<Duration>,
61        rotate_handshake_connection_id: bool,
62    }
63
64    impl Default for Format {
65        fn default() -> Self {
66            Self {
67                len: DEFAULT_LEN,
68                lifetime: None,
69                rotate_handshake_connection_id: true,
70            }
71        }
72    }
73
74    impl Format {
75        /// Creates a builder for the format
76        pub fn builder() -> Builder {
77            Builder::default()
78        }
79    }
80
81    /// A builder for [`Format`] providers
82    #[derive(Debug)]
83    pub struct Builder {
84        len: usize,
85        lifetime: Option<Duration>,
86        rotate_handshake_connection_id: bool,
87    }
88
89    impl Default for Builder {
90        fn default() -> Self {
91            Self {
92                len: DEFAULT_LEN,
93                lifetime: None,
94                rotate_handshake_connection_id: true,
95            }
96        }
97    }
98
99    impl Builder {
100        /// Sets the length of the generated connection Id
101        pub fn with_len(mut self, len: usize) -> Result<Self, connection::id::Error> {
102            if !(connection::LocalId::MIN_LEN..=connection::id::MAX_LEN).contains(&len) {
103                return Err(connection::id::Error::InvalidLength);
104            }
105            self.len = len;
106            Ok(self)
107        }
108
109        /// Sets the lifetime of each generated connection Id
110        pub fn with_lifetime(mut self, lifetime: Duration) -> Result<Self, connection::id::Error> {
111            if !(connection::id::MIN_LIFETIME..=connection::id::MAX_LIFETIME).contains(&lifetime) {
112                return Err(connection::id::Error::InvalidLifetime);
113            }
114            self.lifetime = Some(lifetime);
115            Ok(self)
116        }
117
118        /// Enables/disables rotation of the connection Id used during the handshake (default: enabled)
119        ///
120        /// When enabled (the default), the connection ID used during the the handshake
121        /// will be requested to be retired following confirmation of the handshake
122        /// completing. This reduces linkability between information exchanged
123        /// during and after the handshake.
124        pub fn with_handshake_connection_id_rotation(
125            mut self,
126            enabled: bool,
127        ) -> Result<Self, core::convert::Infallible> {
128            self.rotate_handshake_connection_id = enabled;
129            Ok(self)
130        }
131
132        /// Builds the [`Format`] into a provider
133        pub fn build(self) -> Result<Format, core::convert::Infallible> {
134            Ok(Format {
135                len: self.len,
136                lifetime: self.lifetime,
137                rotate_handshake_connection_id: self.rotate_handshake_connection_id,
138            })
139        }
140    }
141
142    impl Generator for Format {
143        fn generate(&mut self, _connection_info: &ConnectionInfo) -> connection::LocalId {
144            let mut id = [0u8; connection::id::MAX_LEN];
145            let id = &mut id[..self.len];
146            rand::rng().fill_bytes(id);
147            (&*id).try_into().expect("length already checked")
148        }
149
150        fn lifetime(&self) -> Option<Duration> {
151            self.lifetime
152        }
153
154        fn rotate_handshake_connection_id(&self) -> bool {
155            self.rotate_handshake_connection_id
156        }
157    }
158
159    impl Validator for Format {
160        fn validate(&self, _connection_info: &ConnectionInfo, buffer: &[u8]) -> Option<usize> {
161            if buffer.len() >= self.len {
162                Some(self.len)
163            } else {
164                None
165            }
166        }
167    }
168
169    #[cfg(test)]
170    mod tests {
171        use super::*;
172
173        #[test]
174        fn generator_test() {
175            let remote_address = &s2n_quic_core::inet::SocketAddress::default();
176            let connection_info = ConnectionInfo::new(remote_address);
177
178            for len in connection::LocalId::MIN_LEN..connection::id::MAX_LEN {
179                let mut format = Format::builder().with_len(len).unwrap().build().unwrap();
180
181                let id = format.generate(&connection_info);
182
183                //= https://www.rfc-editor.org/rfc/rfc9000#section-10.3.2
184                //= type=test
185                //# An endpoint that uses this design MUST
186                //# either use the same connection ID length for all connections or
187                //# encode the length of the connection ID such that it can be recovered
188                //# without state.
189                assert_eq!(format.validate(&connection_info, id.as_ref()), Some(len));
190                assert_eq!(id.len(), len);
191                assert_eq!(format.lifetime(), None);
192                assert!(format.rotate_handshake_connection_id());
193            }
194
195            assert_eq!(
196                Some(connection::id::Error::InvalidLength),
197                Format::builder()
198                    .with_len(connection::id::MAX_LEN + 1)
199                    .err()
200            );
201
202            assert_eq!(
203                Some(connection::id::Error::InvalidLength),
204                Format::builder()
205                    .with_len(connection::LocalId::MIN_LEN - 1)
206                    .err()
207            );
208
209            let lifetime = Duration::from_secs(1000);
210            let format = Format::builder()
211                .with_lifetime(lifetime)
212                .unwrap()
213                .build()
214                .unwrap();
215            assert_eq!(Some(lifetime), format.lifetime());
216            assert!(format.rotate_handshake_connection_id());
217
218            assert_eq!(
219                Some(connection::id::Error::InvalidLifetime),
220                Format::builder()
221                    .with_lifetime(connection::id::MIN_LIFETIME - Duration::from_millis(1))
222                    .err()
223            );
224
225            assert_eq!(
226                Some(connection::id::Error::InvalidLifetime),
227                Format::builder()
228                    .with_lifetime(connection::id::MAX_LIFETIME + Duration::from_millis(1))
229                    .err()
230            );
231
232            let format = Format::builder().build().unwrap();
233            assert!(format.rotate_handshake_connection_id());
234
235            let format = Format::builder()
236                .with_handshake_connection_id_rotation(true)
237                .unwrap()
238                .build()
239                .unwrap();
240            assert!(format.rotate_handshake_connection_id());
241
242            let format = Format::builder()
243                .with_handshake_connection_id_rotation(false)
244                .unwrap()
245                .build()
246                .unwrap();
247            assert!(!format.rotate_handshake_connection_id());
248        }
249    }
250}