s2n_quic_core/packet/number/
packet_number.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    event::IntoEvent,
6    packet::number::{
7        derive_truncation_range, packet_number_space::PacketNumberSpace,
8        truncated_packet_number::TruncatedPacketNumber,
9    },
10    varint::VarInt,
11};
12use core::{
13    cmp::Ordering,
14    fmt,
15    hash::{Hash, Hasher},
16    mem::size_of,
17    num::NonZeroU64,
18};
19
20#[cfg(any(test, feature = "generator"))]
21use bolero_generator::prelude::*;
22
23const PACKET_SPACE_BITLEN: usize = 2;
24const PACKET_SPACE_SHIFT: usize = (size_of::<PacketNumber>() * 8) - PACKET_SPACE_BITLEN;
25const PACKET_NUMBER_MASK: u64 = u64::MAX >> PACKET_SPACE_BITLEN;
26
/// Contains a fully-decoded packet number in a given space
///
/// Internally the packet number is represented as a [`NonZeroU64`]
/// to ensure optimal memory layout.
///
/// The lower 62 bits are used to store the actual packet number value.
/// The upper 2 bits are used to store the packet number space. Because
/// there are only 3 spaces, the zero state is never used, which is why
/// [`NonZeroU64`] can be used instead of `u64`.
///
/// Equality, ordering and hashing all operate on the raw tagged integer.
/// Packet numbers are only expected to be compared within the same space:
/// builds with debug assertions enabled check that both operands of a
/// comparison share a space.
#[derive(Clone, Copy, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct PacketNumber(NonZeroU64);
39
40impl IntoEvent<u64> for PacketNumber {
41    #[inline]
42    fn into_event(self) -> u64 {
43        self.as_u64()
44    }
45}
46
47impl Default for PacketNumber {
48    fn default() -> Self {
49        Self::from_varint(Default::default(), PacketNumberSpace::Initial)
50    }
51}
52
53impl Hash for PacketNumber {
54    #[inline]
55    fn hash<H: Hasher>(&self, state: &mut H) {
56        self.0.hash(state)
57    }
58}
59
60impl PartialEq for PacketNumber {
61    #[inline]
62    fn eq(&self, other: &Self) -> bool {
63        self.cmp(other) == Ordering::Equal
64    }
65}
66
67impl PartialOrd for PacketNumber {
68    #[inline]
69    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70        Some(self.cmp(other))
71    }
72}
73
74impl Ord for PacketNumber {
75    #[inline]
76    fn cmp(&self, other: &Self) -> Ordering {
77        if cfg!(debug_assertions) {
78            self.space().assert_eq(other.space());
79        }
80        self.0.cmp(&other.0)
81    }
82}
83
84impl fmt::Debug for PacketNumber {
85    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86        f.debug_tuple("PacketNumber")
87            .field(&self.space())
88            .field(&self.as_u64())
89            .finish()
90    }
91}
92
93impl fmt::Display for PacketNumber {
94    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
95        self.as_u64().fmt(f)
96    }
97}
98
99impl PacketNumber {
100    /// Creates a PacketNumber for a given VarInt and PacketNumberSpace
101    #[inline]
102    pub(crate) const fn from_varint(value: VarInt, space: PacketNumberSpace) -> Self {
103        let tag = space.as_tag() as u64;
104        let pn = (tag << PACKET_SPACE_SHIFT) | value.as_u64();
105        let pn = unsafe {
106            // Safety: packet number space tag is never 0
107            NonZeroU64::new_unchecked(pn)
108        };
109        Self(pn)
110    }
111
112    /// Returns the `PacketNumberSpace` for the given `PacketNumber`
113    #[inline]
114    pub fn space(self) -> PacketNumberSpace {
115        let tag = self.0.get() >> PACKET_SPACE_SHIFT;
116        PacketNumberSpace::from_tag(tag as u8)
117    }
118
119    /// Converts the `PacketNumber` into a `VarInt` value.
120    ///
121    /// Note: Even though some scenarios require this function, it should be
122    /// avoided in most cases, as it removes the corresponding `PacketNumberSpace`
123    /// and allows math operations to be performed, which can easily result in
124    /// protocol errors.
125    #[allow(clippy::wrong_self_convention)] // Don't use `self` here to make conversion explicit
126    pub const fn as_varint(packet_number: Self) -> VarInt {
127        // Safety: when converting to a u64, we remove the top 2 bits which
128        //         will force the value to fit into a VarInt.
129        unsafe { VarInt::new_unchecked(packet_number.as_u64()) }
130    }
131
132    /// Truncates the `PacketNumber` into a `TruncatedPacketNumber` based on
133    /// the largest acknowledged packet number
134    #[inline]
135    pub fn truncate(
136        self,
137        largest_acknowledged_packet_number: Self,
138    ) -> Option<TruncatedPacketNumber> {
139        Some(
140            derive_truncation_range(largest_acknowledged_packet_number, self)?
141                .truncate_packet_number(Self::as_varint(self)),
142        )
143    }
144
145    /// Compute the next packet number in the space. If the packet number has
146    /// exceeded the maximum value allowed `None` will be returned.
147    #[inline]
148    pub fn next(self) -> Option<Self> {
149        let value = Self::as_varint(self).checked_add(VarInt::from_u8(1))?;
150        let space = self.space();
151        Some(Self::from_varint(value, space))
152    }
153
154    /// Compute the prev packet number in the space. If the packet number has
155    /// underflowed `None` will be returned.
156    #[inline]
157    pub fn prev(self) -> Option<Self> {
158        let value = Self::as_varint(self).checked_sub(VarInt::from_u8(1))?;
159        let space = self.space();
160        Some(Self::from_varint(value, space))
161    }
162
163    /// Create a nonce for crypto from the packet number value
164    ///
165    /// Note: This should not be used by anything other than crypto-related
166    /// functionality.
167    #[inline]
168    pub const fn as_crypto_nonce(self) -> u64 {
169        self.as_u64()
170    }
171
172    /// Returns the value with the top 2 bits removed
173    #[inline]
174    pub const fn as_u64(self) -> u64 {
175        self.0.get() & PACKET_NUMBER_MASK
176    }
177
178    /// Computes the distance between this packet number and the given packet number,
179    /// returning None if overflow occurred.
180    #[inline]
181    pub fn checked_distance(self, rhs: PacketNumber) -> Option<u64> {
182        self.space().assert_eq(rhs.space());
183        Self::as_u64(self).checked_sub(Self::as_u64(rhs))
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Every packet number space exercised by the tests below
    fn all_spaces() -> [PacketNumberSpace; 3] {
        [
            PacketNumberSpace::Initial,
            PacketNumberSpace::Handshake,
            PacketNumberSpace::ApplicationData,
        ]
    }

    /// The `NonZeroU64` representation relies on every space tag being non-zero
    #[test]
    fn packet_number_space_assumptions_test() {
        for &space in all_spaces().iter() {
            assert_ne!(space.as_tag(), 0);
        }
    }

    /// Packing a value and space into a packet number must be lossless
    #[test]
    fn round_trip_test() {
        // Boundary values around each integer width, up to the largest VarInt
        let values = [
            VarInt::from_u8(0),
            VarInt::from_u8(1),
            VarInt::from_u8(2),
            VarInt::from_u8(u8::MAX / 2),
            VarInt::from_u8(u8::MAX - 1),
            VarInt::from_u8(u8::MAX),
            VarInt::from_u16(u16::MAX / 2),
            VarInt::from_u16(u16::MAX - 1),
            VarInt::from_u16(u16::MAX),
            VarInt::from_u32(u32::MAX / 2),
            VarInt::from_u32(u32::MAX - 1),
            VarInt::from_u32(u32::MAX),
            VarInt::MAX,
        ];

        for &space in all_spaces().iter() {
            for &value in values.iter() {
                let pn = PacketNumber::from_varint(value, space);
                let decoded_space = pn.space();
                let decoded_value = PacketNumber::as_varint(pn);
                assert_eq!(decoded_space, space, "{:#064b}", pn.0);
                assert_eq!(decoded_value, value, "{:#064b}", pn.0);
            }
        }
    }

    /// Measuring the distance across two different spaces must panic
    #[test]
    #[should_panic]
    fn wrong_packet_number_space() {
        let zero = VarInt::from_u8(0);
        let application = PacketNumberSpace::ApplicationData.new_packet_number(zero);
        let handshake = PacketNumberSpace::Handshake.new_packet_number(zero);
        let _ = application.checked_distance(handshake);
    }
}