s2n_quic_core/connection/
id.rs1use crate::{
7 event::{api::SocketAddress, IntoEvent},
8 inet, transport,
9};
10use core::time::Duration;
11use s2n_codec::{decoder_value, Encoder, EncoderValue};
12
13#[cfg(any(test, feature = "generator"))]
14use bolero_generator::prelude::*;
15
16pub const MAX_LEN: usize = crate::packet::long::DESTINATION_CONNECTION_ID_MAX_LEN;
24
25pub const MIN_LIFETIME: Duration = Duration::from_secs(60);
27
28pub const MAX_LIFETIME: Duration = Duration::from_secs(24 * 60 * 60); macro_rules! id {
33 ($type:ident, $min_len:expr) => {
34 #[derive(Copy, Clone, Eq)]
36 #[cfg_attr(any(feature = "generator", test), derive(TypeGenerator))]
37 pub struct $type {
38 bytes: [u8; MAX_LEN],
39 #[cfg_attr(any(feature = "generator", test), generator(Self::GENERATOR))]
40 len: u8,
41 }
42
43 impl core::fmt::Debug for $type {
44 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
45 write!(f, "{}({:?})", stringify!($type), self.as_bytes())
46 }
47 }
48
49 impl PartialEq for $type {
50 fn eq(&self, other: &Self) -> bool {
51 self.as_bytes() == other.as_bytes()
52 }
53 }
54
55 impl core::hash::Hash for $type {
56 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
57 self.as_bytes().hash(state);
58 }
59 }
60
61 impl PartialOrd for $type {
62 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
63 Some(self.cmp(other))
64 }
65 }
66
67 impl Ord for $type {
68 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
69 self.as_bytes().cmp(&other.as_bytes())
70 }
71 }
72
73 impl $type {
74 pub const MIN_LEN: usize = $min_len;
76
77 #[cfg(any(feature = "generator", test))]
78 const GENERATOR: core::ops::RangeInclusive<u8> = $min_len..=(MAX_LEN as u8);
79
80 #[inline]
86 pub fn try_from_bytes(bytes: &[u8]) -> Option<$type> {
87 Self::try_from(bytes).ok()
88 }
89
90 #[inline]
92 pub fn as_bytes(&self) -> &[u8] {
93 self.as_ref()
94 }
95
96 #[inline]
98 pub const fn len(&self) -> usize {
99 self.len as usize
100 }
101
102 #[inline]
104 pub fn is_empty(&self) -> bool {
105 self.len == 0
106 }
107
108 #[cfg(any(test, feature = "testing"))]
110 pub const TEST_ID: Self = Self::test_id();
111
112 #[cfg(any(test, feature = "testing"))]
116 const fn test_id() -> Self {
117 let type_bytes = stringify!($type).as_bytes();
118 let mut result = [0u8; MAX_LEN];
119 result[0] = type_bytes[0];
120 result[1] = type_bytes[1];
121 result[2] = type_bytes[2];
122 result[3] = type_bytes[3];
123 result[4] = type_bytes[4];
124 result[5] = type_bytes[5];
125 result[14] = type_bytes[0];
126 result[15] = type_bytes[1];
127 result[16] = type_bytes[2];
128 result[17] = type_bytes[3];
129 result[18] = type_bytes[4];
130 result[19] = type_bytes[5];
131 Self {
132 bytes: result,
133 len: MAX_LEN as u8,
134 }
135 }
136 }
137
138 impl From<[u8; MAX_LEN]> for $type {
139 #[inline]
140 fn from(bytes: [u8; MAX_LEN]) -> Self {
141 Self {
142 bytes,
143 len: MAX_LEN as u8,
144 }
145 }
146 }
147
148 impl TryFrom<&[u8]> for $type {
149 type Error = Error;
150
151 #[inline]
152 fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
153 let len = slice.len();
154 if !($type::MIN_LEN..=MAX_LEN).contains(&len) {
155 return Err(Error::InvalidLength);
156 }
157 let mut bytes = [0; MAX_LEN];
158 bytes[..len].copy_from_slice(slice);
159 Ok(Self {
160 bytes,
161 len: len as u8,
162 })
163 }
164 }
165
166 impl AsRef<[u8]> for $type {
167 #[inline]
168 fn as_ref(&self) -> &[u8] {
169 &self.bytes[0..self.len as usize]
170 }
171 }
172
173 decoder_value!(
174 impl<'a> $type {
175 fn decode(buffer: Buffer) -> Result<Self> {
176 let len = buffer.len();
177 let (value, buffer) = buffer.decode_slice(len)?;
178 let value: &[u8] = value.into_less_safe_slice();
179 let connection_id = $type::try_from(value).map_err(|_| {
180 s2n_codec::DecoderError::InvariantViolation(concat!(
181 "invalid ",
182 stringify!($type)
183 ))
184 })?;
185
186 Ok((connection_id, buffer))
187 }
188 }
189 );
190
191 impl EncoderValue for $type {
192 #[inline]
193 fn encode<E: Encoder>(&self, encoder: &mut E) {
194 self.as_ref().encode(encoder)
195 }
196 }
197
198 impl Default for $type {
201 fn default() -> Self {
202 unimplemented!("connection IDs do not have default values")
203 }
204 }
205 };
206}
207
208id!(LocalId, 4);
212
213id!(PeerId, 0);
216
217id!(UnboundedId, 0);
220id!(InitialId, 8);
228
229impl From<LocalId> for UnboundedId {
230 fn from(id: LocalId) -> Self {
231 UnboundedId {
232 bytes: id.bytes,
233 len: id.len,
234 }
235 }
236}
237
238impl From<PeerId> for UnboundedId {
239 fn from(id: PeerId) -> Self {
240 UnboundedId {
241 bytes: id.bytes,
242 len: id.len,
243 }
244 }
245}
246
247impl From<InitialId> for UnboundedId {
248 fn from(id: InitialId) -> Self {
249 UnboundedId {
250 bytes: id.bytes,
251 len: id.len,
252 }
253 }
254}
255
256impl From<InitialId> for PeerId {
257 fn from(id: InitialId) -> Self {
258 PeerId {
259 bytes: id.bytes,
260 len: id.len,
261 }
262 }
263}
264
265impl TryFrom<LocalId> for InitialId {
268 type Error = Error;
269
270 #[inline]
271 fn try_from(value: LocalId) -> Result<Self, Self::Error> {
272 value.as_bytes().try_into()
273 }
274}
275
/// Classifies a connection ID as either `Initial` or `Local`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Classification {
    /// The ID is classified as an initial connection ID.
    Initial,
    /// The ID is classified as a local connection ID.
    Local,
}

impl Classification {
    /// Returns `true` only for [`Classification::Initial`].
    #[inline]
    pub fn is_initial(&self) -> bool {
        match self {
            Self::Initial => true,
            Self::Local => false,
        }
    }

    /// Returns `true` only for [`Classification::Local`].
    #[inline]
    pub fn is_local(&self) -> bool {
        match self {
            Self::Initial => false,
            Self::Local => true,
        }
    }
}
295
/// Errors produced when constructing or validating connection IDs.
///
/// This is a fieldless enum, so it is `Copy` and has total equality (`Eq`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The connection ID length was outside the range allowed for its type.
    InvalidLength,
    /// The connection ID lifetime was invalid.
    ///
    /// NOTE(review): presumably outside `MIN_LIFETIME..=MAX_LIFETIME`; confirm at the call site.
    InvalidLifetime,
}

impl Error {
    /// Returns a static, human-readable description of the error.
    #[inline]
    const fn message(&self) -> &'static str {
        match self {
            Self::InvalidLength => "invalid connection id length",
            Self::InvalidLifetime => "invalid connection id lifetime",
        }
    }
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.message())
    }
}
316
317impl From<Error> for transport::Error {
318 #[inline]
319 fn from(error: Error) -> Self {
320 Self::PROTOCOL_VIOLATION.with_reason(error.message())
321 }
322}
323
324#[derive(Clone, Debug)]
327#[non_exhaustive]
328pub struct ConnectionInfo<'a> {
329 pub remote_address: SocketAddress<'a>,
330}
331
332impl<'a> ConnectionInfo<'a> {
333 #[inline]
334 #[doc(hidden)]
335 pub fn new(remote_address: &'a inet::SocketAddress) -> Self {
336 Self {
337 remote_address: remote_address.into_event(),
338 }
339 }
340}
341
342pub trait Format: 'static + Validator + Generator + Send {}
344
345impl<T: 'static + Validator + Generator + Send> Format for T {}
347
348pub trait Validator {
350 fn validate(&self, connection_info: &ConnectionInfo, buffer: &[u8]) -> Option<usize>;
362}
363
364impl Validator for usize {
365 #[inline]
366 fn validate(&self, _connection_info: &ConnectionInfo, buffer: &[u8]) -> Option<usize> {
367 if buffer.len() >= *self {
368 Some(*self)
369 } else {
370 None
371 }
372 }
373}
374
375pub trait Generator {
377 fn generate(&mut self, connection_info: &ConnectionInfo) -> LocalId;
387
388 #[inline]
392 fn lifetime(&self) -> Option<core::time::Duration> {
393 None
394 }
395
396 #[inline]
401 fn rotate_handshake_connection_id(&self) -> bool {
402 true
403 }
404}
405
/// Interest in new connection IDs.
///
/// The ordering is derived, so `None` sorts before every `New(_)`, and `New` values compare by
/// their `u8` payload. NOTE(review): the payload is presumably a count of IDs wanted; confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum Interest {
    /// No interest in new connection IDs (the default value).
    #[default]
    None,
    /// Interest in new connection IDs, carrying a `u8` amount.
    New(u8),
}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_connection_id() {
        let connection_id = LocalId::try_from_bytes(b"My Connection 123").unwrap();
        assert_eq!(b"My Connection 123", connection_id.as_bytes());

        let connection_id = PeerId::try_from_bytes(b"My Connection 456").unwrap();
        assert_eq!(b"My Connection 456", connection_id.as_bytes());

        let connection_id = InitialId::try_from_bytes(b"My Connection 789").unwrap();
        assert_eq!(b"My Connection 789", connection_id.as_bytes());
    }

    #[test]
    fn exceed_max_connection_id_length() {
        let connection_id_bytes = [0u8; MAX_LEN];
        assert!(LocalId::try_from_bytes(&connection_id_bytes).is_some());
        assert!(PeerId::try_from_bytes(&connection_id_bytes).is_some());
        assert!(InitialId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; MAX_LEN + 1];
        assert!(LocalId::try_from_bytes(&connection_id_bytes).is_none());
        assert!(PeerId::try_from_bytes(&connection_id_bytes).is_none());
        assert!(InitialId::try_from_bytes(&connection_id_bytes).is_none());
    }

    #[test]
    fn min_connection_id_length() {
        let connection_id_bytes = [0u8; LocalId::MIN_LEN];
        assert!(LocalId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; PeerId::MIN_LEN];
        assert!(PeerId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; InitialId::MIN_LEN];
        assert!(InitialId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; LocalId::MIN_LEN - 1];
        assert!(LocalId::try_from_bytes(&connection_id_bytes).is_none());

        let connection_id_bytes = [0u8; InitialId::MIN_LEN - 1];
        assert!(InitialId::try_from_bytes(&connection_id_bytes).is_none());
    }

    #[test]
    fn unbounded_id() {
        let connection_id_bytes = [0u8; LocalId::MIN_LEN];
        assert!(UnboundedId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; PeerId::MIN_LEN];
        assert!(UnboundedId::try_from_bytes(&connection_id_bytes).is_some());

        let connection_id_bytes = [0u8; InitialId::MIN_LEN];
        assert!(UnboundedId::try_from_bytes(&connection_id_bytes).is_some());
    }

    /// `TEST_ID` is full length, carries the first six bytes of the type name at both ends, and
    /// is zero in between.
    #[test]
    fn test_ids() {
        fn check(name: &str, id: &[u8]) {
            let tag = &name.as_bytes()[..6];
            assert_eq!(id.len(), MAX_LEN);
            assert_eq!(&id[..6], tag);
            assert_eq!(&id[MAX_LEN - 6..], tag);
            assert!(id[6..MAX_LEN - 6].iter().all(|&b| b == 0));
        }

        check("LocalId", LocalId::TEST_ID.as_bytes());
        check("PeerId", PeerId::TEST_ID.as_bytes());
        check("UnboundedId", UnboundedId::TEST_ID.as_bytes());
        check("InitialId", InitialId::TEST_ID.as_bytes());

        assert!(format!("{:?}", LocalId::TEST_ID).starts_with("LocalId(["));
        assert!(format!("{:?}", InitialId::TEST_ID).starts_with("InitialId(["));
    }

    /// Equality and ordering look only at the meaningful bytes, not the zero padding.
    #[test]
    fn equality_and_ordering_ignore_padding() {
        let short = LocalId::try_from_bytes(&[1u8, 2, 3, 4]).unwrap();
        let same = LocalId::try_from(&[1u8, 2, 3, 4][..]).unwrap();
        assert_eq!(short, same);

        // Same leading bytes, but one more (zero) byte: must differ from the padded form.
        let longer = LocalId::try_from_bytes(&[1u8, 2, 3, 4, 0]).unwrap();
        assert_ne!(short, longer);
        assert!(short < longer);
    }

    #[test]
    fn conversions() {
        let local = LocalId::try_from_bytes(&[7u8; 8]).unwrap();
        assert_eq!(UnboundedId::from(local).as_bytes(), local.as_bytes());

        let initial = InitialId::try_from(local).unwrap();
        assert_eq!(initial.as_bytes(), local.as_bytes());
        assert_eq!(PeerId::from(initial).as_bytes(), local.as_bytes());
        assert_eq!(UnboundedId::from(initial).as_bytes(), local.as_bytes());

        // A 4-byte LocalId is valid, but too short to be an InitialId.
        let too_short = LocalId::try_from_bytes(&[7u8; 4]).unwrap();
        assert_eq!(InitialId::try_from(too_short), Err(Error::InvalidLength));
    }
}
478
/// Deterministic connection ID helpers for tests.
#[cfg(any(test, feature = "testing"))]
pub mod testing {
    use super::*;
    use core::convert::TryInto;

    /// A connection ID format for tests.
    ///
    /// It generates sequential 8-byte IDs from a big-endian `u64` counter, and its validator
    /// always reports an 8-byte connection ID.
    #[derive(Debug, Default)]
    pub struct Format(u64);

    impl Validator for Format {
        fn validate(&self, _connection_info: &ConnectionInfo, _buffer: &[u8]) -> Option<usize> {
            // Every ID produced by `generate` is exactly one `u64` wide; the buffer is not read.
            const ID_LEN: usize = core::mem::size_of::<u64>();
            Some(ID_LEN)
        }
    }

    impl Generator for Format {
        fn generate(&mut self, _connection_info: &ConnectionInfo) -> LocalId {
            let counter = self.0;
            self.0 += 1;
            let encoded = counter.to_be_bytes();
            // 8 bytes is within `LocalId`'s bounds (minimum 4, maximum at least 20), so the
            // conversion cannot fail.
            encoded.as_slice().try_into().unwrap()
        }
    }
}