s2n_quic/provider/
connection_id.rs1pub use s2n_quic_core::connection::id::{ConnectionInfo, Format, Generator, LocalId, Validator};
7
8pub trait Provider: 'static {
9 type Format: 'static + Format;
10 type Error: core::fmt::Display + Send + Sync;
11
12 fn start(self) -> Result<Self::Format, Self::Error>;
13}
14
15pub use default::Provider as Default;
16
17impl_provider_utils!();
18
19impl<T: 'static + Format> Provider for T {
20 type Format = T;
21 type Error = core::convert::Infallible;
22
23 fn start(self) -> Result<Self::Format, Self::Error> {
24 Ok(self)
25 }
26}
27
28pub mod default {
29 use core::{
30 convert::{Infallible, TryInto},
31 time::Duration,
32 };
33 use rand::prelude::*;
34 use s2n_quic_core::connection::{
35 self,
36 id::{ConnectionInfo, Generator, Validator},
37 };
38
39 #[derive(Debug, Default)]
40 pub struct Provider(Format);
41
42 impl super::Provider for Provider {
43 type Format = Format;
44 type Error = Infallible;
45
46 fn start(self) -> Result<Self::Format, Self::Error> {
47 Ok(self.0)
48 }
49 }
50
51 const DEFAULT_LEN: usize = 16;
53
54 #[derive(Debug)]
58 pub struct Format {
59 len: usize,
60 lifetime: Option<Duration>,
61 rotate_handshake_connection_id: bool,
62 }
63
64 impl Default for Format {
65 fn default() -> Self {
66 Self {
67 len: DEFAULT_LEN,
68 lifetime: None,
69 rotate_handshake_connection_id: true,
70 }
71 }
72 }
73
74 impl Format {
75 pub fn builder() -> Builder {
77 Builder::default()
78 }
79 }
80
81 #[derive(Debug)]
83 pub struct Builder {
84 len: usize,
85 lifetime: Option<Duration>,
86 rotate_handshake_connection_id: bool,
87 }
88
89 impl Default for Builder {
90 fn default() -> Self {
91 Self {
92 len: DEFAULT_LEN,
93 lifetime: None,
94 rotate_handshake_connection_id: true,
95 }
96 }
97 }
98
99 impl Builder {
100 pub fn with_len(mut self, len: usize) -> Result<Self, connection::id::Error> {
102 if !(connection::LocalId::MIN_LEN..=connection::id::MAX_LEN).contains(&len) {
103 return Err(connection::id::Error::InvalidLength);
104 }
105 self.len = len;
106 Ok(self)
107 }
108
109 pub fn with_lifetime(mut self, lifetime: Duration) -> Result<Self, connection::id::Error> {
111 if !(connection::id::MIN_LIFETIME..=connection::id::MAX_LIFETIME).contains(&lifetime) {
112 return Err(connection::id::Error::InvalidLifetime);
113 }
114 self.lifetime = Some(lifetime);
115 Ok(self)
116 }
117
118 pub fn with_handshake_connection_id_rotation(
125 mut self,
126 enabled: bool,
127 ) -> Result<Self, core::convert::Infallible> {
128 self.rotate_handshake_connection_id = enabled;
129 Ok(self)
130 }
131
132 pub fn build(self) -> Result<Format, core::convert::Infallible> {
134 Ok(Format {
135 len: self.len,
136 lifetime: self.lifetime,
137 rotate_handshake_connection_id: self.rotate_handshake_connection_id,
138 })
139 }
140 }
141
142 impl Generator for Format {
143 fn generate(&mut self, _connection_info: &ConnectionInfo) -> connection::LocalId {
144 let mut id = [0u8; connection::id::MAX_LEN];
145 let id = &mut id[..self.len];
146 rand::rng().fill_bytes(id);
147 (&*id).try_into().expect("length already checked")
148 }
149
150 fn lifetime(&self) -> Option<Duration> {
151 self.lifetime
152 }
153
154 fn rotate_handshake_connection_id(&self) -> bool {
155 self.rotate_handshake_connection_id
156 }
157 }
158
159 impl Validator for Format {
160 fn validate(&self, _connection_info: &ConnectionInfo, buffer: &[u8]) -> Option<usize> {
161 if buffer.len() >= self.len {
162 Some(self.len)
163 } else {
164 None
165 }
166 }
167 }
168
169 #[cfg(test)]
170 mod tests {
171 use super::*;
172
173 #[test]
174 fn generator_test() {
175 let remote_address = &s2n_quic_core::inet::SocketAddress::default();
176 let connection_info = ConnectionInfo::new(remote_address);
177
178 for len in connection::LocalId::MIN_LEN..connection::id::MAX_LEN {
179 let mut format = Format::builder().with_len(len).unwrap().build().unwrap();
180
181 let id = format.generate(&connection_info);
182
183 assert_eq!(format.validate(&connection_info, id.as_ref()), Some(len));
190 assert_eq!(id.len(), len);
191 assert_eq!(format.lifetime(), None);
192 assert!(format.rotate_handshake_connection_id());
193 }
194
195 assert_eq!(
196 Some(connection::id::Error::InvalidLength),
197 Format::builder()
198 .with_len(connection::id::MAX_LEN + 1)
199 .err()
200 );
201
202 assert_eq!(
203 Some(connection::id::Error::InvalidLength),
204 Format::builder()
205 .with_len(connection::LocalId::MIN_LEN - 1)
206 .err()
207 );
208
209 let lifetime = Duration::from_secs(1000);
210 let format = Format::builder()
211 .with_lifetime(lifetime)
212 .unwrap()
213 .build()
214 .unwrap();
215 assert_eq!(Some(lifetime), format.lifetime());
216 assert!(format.rotate_handshake_connection_id());
217
218 assert_eq!(
219 Some(connection::id::Error::InvalidLifetime),
220 Format::builder()
221 .with_lifetime(connection::id::MIN_LIFETIME - Duration::from_millis(1))
222 .err()
223 );
224
225 assert_eq!(
226 Some(connection::id::Error::InvalidLifetime),
227 Format::builder()
228 .with_lifetime(connection::id::MAX_LIFETIME + Duration::from_millis(1))
229 .err()
230 );
231
232 let format = Format::builder().build().unwrap();
233 assert!(format.rotate_handshake_connection_id());
234
235 let format = Format::builder()
236 .with_handshake_connection_id_rotation(true)
237 .unwrap()
238 .build()
239 .unwrap();
240 assert!(format.rotate_handshake_connection_id());
241
242 let format = Format::builder()
243 .with_handshake_connection_id_rotation(false)
244 .unwrap()
245 .build()
246 .unwrap();
247 assert!(!format.rotate_handshake_connection_id());
248 }
249 }
250}