1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
    event::IntoEvent,
    packet::number::{
        derive_truncation_range, packet_number_space::PacketNumberSpace,
        truncated_packet_number::TruncatedPacketNumber,
    },
    varint::VarInt,
};
use core::{
    cmp::Ordering,
    fmt,
    hash::{Hash, Hasher},
    mem::size_of,
    num::NonZeroU64,
};

#[cfg(any(test, feature = "generator"))]
use bolero_generator::*;

/// Number of high-order bits of the underlying 64-bit value reserved for the
/// packet number space tag.
const PACKET_SPACE_BITLEN: usize = 2;
/// Left shift that positions a space tag in the topmost `PACKET_SPACE_BITLEN`
/// bits of the underlying value.
const PACKET_SPACE_SHIFT: usize = 8 * size_of::<PacketNumber>() - PACKET_SPACE_BITLEN;
/// Mask selecting every bit below the space tag, i.e. the packet number value.
const PACKET_NUMBER_MASK: u64 = (1u64 << PACKET_SPACE_SHIFT) - 1;

/// A fully-decoded packet number tagged with the space it belongs to.
///
/// Internally the packet number is represented as a [`NonZeroU64`]
/// to ensure optimal memory layout.
///
/// Bit layout of the inner value:
///
/// * the lower 62 bits store the actual packet number value;
/// * the upper 2 bits store the packet number space tag.
///
/// Because there are only 3 spaces, the all-zero tag is never used, which is
/// why [`NonZeroU64`] can be used instead of `u64`.
#[derive(Clone, Copy, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct PacketNumber(NonZeroU64);

impl IntoEvent<u64> for PacketNumber {
    #[inline]
    fn into_event(self) -> u64 {
        self.as_u64()
    }
}

impl Default for PacketNumber {
    fn default() -> Self {
        Self::from_varint(Default::default(), PacketNumberSpace::Initial)
    }
}

/// Hashes the complete tagged representation, so both the space tag and the
/// packet number value contribute to the hash.
impl Hash for PacketNumber {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.0, state)
    }
}

/// Equality is defined through [`Ord::cmp`], so it inherits the debug-build
/// check that both operands belong to the same packet number space.
impl PartialEq for PacketNumber {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let ordering = Ord::cmp(self, other);
        ordering == Ordering::Equal
    }
}

/// Packet numbers are totally ordered; this simply wraps [`Ord::cmp`].
impl PartialOrd for PacketNumber {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

/// Orders packet numbers by their underlying tagged value.
///
/// When debug assertions are enabled, the spaces of both operands are checked
/// with `PacketNumberSpace::assert_eq` before comparing.
impl Ord for PacketNumber {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (*self, *other);
        if cfg!(debug_assertions) {
            lhs.space().assert_eq(rhs.space());
        }
        Ord::cmp(&lhs.0, &rhs.0)
    }
}

/// Renders as a `PacketNumber` tuple holding the space followed by the
/// numeric value.
impl fmt::Debug for PacketNumber {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let space = self.space();
        let value = self.as_u64();

        let mut tuple = f.debug_tuple("PacketNumber");
        tuple.field(&space);
        tuple.field(&value);
        tuple.finish()
    }
}

/// Displays only the numeric packet number value, without the space.
impl fmt::Display for PacketNumber {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value = self.as_u64();
        fmt::Display::fmt(&value, f)
    }
}

impl PacketNumber {
    /// Builds a `PacketNumber` from a `VarInt` value and the
    /// `PacketNumberSpace` it belongs to.
    ///
    /// The space tag is placed in the upper bits of the underlying value and
    /// the `VarInt` value occupies the remaining lower bits.
    #[inline]
    pub(crate) const fn from_varint(value: VarInt, space: PacketNumberSpace) -> Self {
        let space_bits = (space.as_tag() as u64) << PACKET_SPACE_SHIFT;
        let raw = space_bits | value.as_u64();
        // Safety: packet number space tag is never 0
        Self(unsafe { NonZeroU64::new_unchecked(raw) })
    }

    /// Returns the `PacketNumberSpace` this `PacketNumber` was created in.
    #[inline]
    pub fn space(self) -> PacketNumberSpace {
        let raw = self.0.get();
        // The tag lives in the topmost bits, so shifting leaves only the tag.
        PacketNumberSpace::from_tag((raw >> PACKET_SPACE_SHIFT) as u8)
    }

    /// Converts the `PacketNumber` into a plain `VarInt` value.
    ///
    /// Note: Although a few scenarios need this conversion, prefer to avoid
    /// it: the result no longer carries its `PacketNumberSpace` and permits
    /// arbitrary math, which can easily result in protocol errors.
    #[allow(clippy::wrong_self_convention)] // Don't use `self` here to make conversion explicit
    pub const fn as_varint(packet_number: Self) -> VarInt {
        let value = packet_number.as_u64();
        // Safety: `as_u64` clears the top 2 bits, which forces the value to
        //         fit into a VarInt.
        unsafe { VarInt::new_unchecked(value) }
    }

    /// Truncates the `PacketNumber` into a `TruncatedPacketNumber`, using the
    /// largest acknowledged packet number to derive the truncation range.
    ///
    /// Returns `None` when no truncation range can be derived.
    #[inline]
    pub fn truncate(
        self,
        largest_acknowledged_packet_number: Self,
    ) -> Option<TruncatedPacketNumber> {
        let range = derive_truncation_range(largest_acknowledged_packet_number, self)?;
        let value = Self::as_varint(self);
        Some(range.truncate_packet_number(value))
    }

    /// Returns the packet number that follows this one in the same space, or
    /// `None` if incrementing would exceed the maximum allowed value.
    #[inline]
    pub fn next(self) -> Option<Self> {
        let one = VarInt::from_u8(1);
        Self::as_varint(self)
            .checked_add(one)
            .map(|incremented| Self::from_varint(incremented, self.space()))
    }

    /// Returns the packet number that precedes this one in the same space, or
    /// `None` if decrementing would underflow.
    #[inline]
    pub fn prev(self) -> Option<Self> {
        let one = VarInt::from_u8(1);
        Self::as_varint(self)
            .checked_sub(one)
            .map(|decremented| Self::from_varint(decremented, self.space()))
    }

    /// Produces the value used to build a crypto nonce from the packet number.
    ///
    /// Note: Only crypto-related functionality should call this.
    #[inline]
    pub const fn as_crypto_nonce(self) -> u64 {
        Self::as_u64(self)
    }

    /// Returns the packet number value with the top 2 (space tag) bits removed.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        PACKET_NUMBER_MASK & self.0.get()
    }

    /// Computes `self - rhs` on the packet number values.
    ///
    /// Returns `None` if the subtraction would underflow, i.e. when `rhs` is
    /// larger than `self`.
    ///
    /// # Panics
    ///
    /// Panics if the two packet numbers belong to different spaces.
    #[inline]
    pub fn checked_distance(self, rhs: PacketNumber) -> Option<u64> {
        self.space().assert_eq(rhs.space());
        let (lhs_value, rhs_value) = (self.as_u64(), rhs.as_u64());
        lhs_value.checked_sub(rhs_value)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The `NonZeroU64` representation relies on every packet number space
    /// having a non-zero tag; verify that assumption holds.
    #[test]
    fn packet_number_space_assumptions_test() {
        assert!(PacketNumberSpace::Initial.as_tag() != 0);
        assert!(PacketNumberSpace::Handshake.as_tag() != 0);
        assert!(PacketNumberSpace::ApplicationData.as_tag() != 0);
    }

    /// Encoding a value/space pair and decoding it again must yield the
    /// original pair, including at the integer-width boundaries.
    #[test]
    fn round_trip_test() {
        let values = [
            VarInt::from_u8(0),
            VarInt::from_u8(1),
            VarInt::from_u8(2),
            VarInt::from_u8(core::u8::MAX / 2),
            VarInt::from_u8(core::u8::MAX - 1),
            VarInt::from_u8(core::u8::MAX),
            VarInt::from_u16(core::u16::MAX / 2),
            VarInt::from_u16(core::u16::MAX - 1),
            VarInt::from_u16(core::u16::MAX),
            VarInt::from_u32(core::u32::MAX / 2),
            VarInt::from_u32(core::u32::MAX - 1),
            VarInt::from_u32(core::u32::MAX),
            VarInt::MAX,
        ];

        let spaces = [
            PacketNumberSpace::Initial,
            PacketNumberSpace::Handshake,
            PacketNumberSpace::ApplicationData,
        ];

        for space in spaces.iter().copied() {
            for value in values.iter().copied() {
                let pn = PacketNumber::from_varint(value, space);
                assert_eq!(pn.space(), space, "{:#064b}", pn.0);
                assert_eq!(PacketNumber::as_varint(pn), value, "{:#064b}", pn.0);
            }
        }
    }

    /// Measuring the distance between packet numbers from different spaces
    /// must panic.
    #[test]
    #[should_panic]
    fn wrong_packet_number_space() {
        let zero = VarInt::from_u8(0);
        let application = PacketNumberSpace::ApplicationData.new_packet_number(zero);
        let handshake = PacketNumberSpace::Handshake.new_packet_number(zero);
        let _ = application.checked_distance(handshake);
    }
}