use crate::{
event::IntoEvent,
packet::number::{
derive_truncation_range, packet_number_space::PacketNumberSpace,
truncated_packet_number::TruncatedPacketNumber,
},
varint::VarInt,
};
use core::{
cmp::Ordering,
fmt,
hash::{Hash, Hasher},
mem::size_of,
num::NonZeroU64,
};
#[cfg(any(test, feature = "generator"))]
use bolero_generator::*;
/// Number of high-order bits of the packed representation reserved for the space tag.
const PACKET_SPACE_BITLEN: usize = 2;
/// Bit offset at which the space tag starts: the tag occupies the top `PACKET_SPACE_BITLEN`
/// bits of the `PacketNumber` representation.
const PACKET_SPACE_SHIFT: usize = size_of::<PacketNumber>() * 8 - PACKET_SPACE_BITLEN;
/// Selects every bit below the space tag, i.e. the numeric packet number value.
const PACKET_NUMBER_MASK: u64 = (1u64 << PACKET_SPACE_SHIFT) - 1;
/// A packet number tagged with the packet number space it belongs to.
///
/// Layout of the inner integer (see `PacketNumber::from_varint`):
/// - the top `PACKET_SPACE_BITLEN` bits hold the space tag (`PacketNumberSpace::as_tag`);
/// - the remaining low bits hold the packet number value (`PACKET_NUMBER_MASK`).
///
/// Every space tag is non-zero (checked by `packet_number_space_assumptions_test`), so the packed
/// value is never zero, which is what allows it to be stored as a `NonZeroU64`.
///
/// NOTE(review): the derived `TypeGenerator` presumably produces arbitrary non-zero bit patterns,
/// so generated values may carry a tag that matches no `PacketNumberSpace` — confirm that
/// `space()` tolerates this in tests/fuzzing.
#[derive(Clone, Copy, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct PacketNumber(NonZeroU64);
impl IntoEvent<u64> for PacketNumber {
    /// Events carry only the numeric packet number value; the space tag is dropped.
    #[inline]
    fn into_event(self) -> u64 {
        Self::as_u64(self)
    }
}
impl Default for PacketNumber {
fn default() -> Self {
Self::from_varint(Default::default(), PacketNumberSpace::Initial)
}
}
impl Hash for PacketNumber {
    /// Hashes the full packed representation (space tag and value), which keeps hashing
    /// consistent with `Eq`, since `Eq` also compares the packed representation.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.0, state)
    }
}
impl PartialEq for PacketNumber {
    /// Equality is defined through `Ord::cmp`, so it shares the debug-build check that both
    /// operands belong to the same packet number space.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}
impl PartialOrd for PacketNumber {
    /// Always `Some`: packet numbers are totally ordered via `Ord::cmp`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for PacketNumber {
    /// Orders packet numbers by their packed representation.
    ///
    /// Comparing numbers from different spaces is a logic error; debug builds check it with
    /// `PacketNumberSpace::assert_eq`. Release builds skip the check.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let verify_same_space = cfg!(debug_assertions);
        if verify_same_space {
            self.space().assert_eq(other.space());
        }
        Ord::cmp(&self.0, &other.0)
    }
}
impl fmt::Debug for PacketNumber {
    /// Renders as `PacketNumber(<space>, <value>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut tuple = f.debug_tuple("PacketNumber");
        tuple.field(&self.space());
        tuple.field(&self.as_u64());
        tuple.finish()
    }
}
impl fmt::Display for PacketNumber {
    /// Displays only the numeric value; the space tag is omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value = self.as_u64();
        fmt::Display::fmt(&value, f)
    }
}
impl PacketNumber {
    /// Packs `value` together with its packet number `space`.
    ///
    /// The space tag is shifted into the top `PACKET_SPACE_BITLEN` bits and OR-ed with the
    /// numeric value, which occupies the low bits.
    #[inline]
    pub(crate) const fn from_varint(value: VarInt, space: PacketNumberSpace) -> Self {
        let tag = space.as_tag() as u64;
        let pn = (tag << PACKET_SPACE_SHIFT) | value.as_u64();
        // SAFETY: every `PacketNumberSpace` tag is non-zero and fits in `PACKET_SPACE_BITLEN`
        // bits, and a `VarInt` never exceeds 62 bits, so it cannot clobber the tag. Both
        // invariants are exercised by `packet_number_space_assumptions_test` and
        // `round_trip_test`. The shifted tag therefore sets at least one high bit, so `pn != 0`.
        let pn = unsafe { NonZeroU64::new_unchecked(pn) };
        Self(pn)
    }

    /// Returns the packet number space encoded in the top bits.
    #[inline]
    pub fn space(self) -> PacketNumberSpace {
        let tag = self.0.get() >> PACKET_SPACE_SHIFT;
        PacketNumberSpace::from_tag(tag as u8)
    }

    /// Returns the numeric packet number value with the space tag stripped.
    ///
    /// This is an associated function (`PacketNumber::as_varint(pn)`) rather than a method.
    /// NOTE(review): presumably this is deliberate, so that dropping the space is explicit at
    /// call sites. The clippy allow is kept to match the existing lint configuration — confirm
    /// whether it is still needed.
    #[inline]
    #[allow(clippy::wrong_self_convention)]
    pub const fn as_varint(packet_number: Self) -> VarInt {
        // SAFETY: `as_u64` masks off the space tag with `PACKET_NUMBER_MASK`, leaving at most
        // 62 significant bits. That is exactly the range of a QUIC variable-length integer.
        unsafe { VarInt::new_unchecked(packet_number.as_u64()) }
    }

    /// Truncates this packet number for encoding, relative to
    /// `largest_acknowledged_packet_number`.
    ///
    /// Returns `None` when `derive_truncation_range` cannot produce a range for the pair.
    #[inline]
    pub fn truncate(
        self,
        largest_acknowledged_packet_number: Self,
    ) -> Option<TruncatedPacketNumber> {
        let range = derive_truncation_range(largest_acknowledged_packet_number, self)?;
        Some(range.truncate_packet_number(Self::as_varint(self)))
    }

    /// Returns the next packet number in the same space.
    ///
    /// Returns `None` if incrementing the value overflows (`VarInt::checked_add` fails).
    #[inline]
    pub fn next(self) -> Option<Self> {
        let value = Self::as_varint(self).checked_add(VarInt::from_u8(1))?;
        Some(Self::from_varint(value, self.space()))
    }

    /// Returns the previous packet number in the same space.
    ///
    /// Returns `None` if decrementing the value underflows (`VarInt::checked_sub` fails).
    #[inline]
    pub fn prev(self) -> Option<Self> {
        let value = Self::as_varint(self).checked_sub(VarInt::from_u8(1))?;
        Some(Self::from_varint(value, self.space()))
    }

    /// Returns the value used when building the crypto nonce. This is the numeric packet
    /// number without the space tag.
    #[inline]
    pub const fn as_crypto_nonce(self) -> u64 {
        self.as_u64()
    }

    /// Returns the numeric packet number value with the space tag masked off.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        self.0.get() & PACKET_NUMBER_MASK
    }

    /// Returns `self - rhs` when `self >= rhs`, and `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics (in all build profiles) if `self` and `rhs` belong to different packet number
    /// spaces.
    #[inline]
    pub fn checked_distance(self, rhs: PacketNumber) -> Option<u64> {
        self.space().assert_eq(rhs.space());
        self.as_u64().checked_sub(rhs.as_u64())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `PacketNumber::from_varint` relies on every space tag being non-zero.
    #[test]
    fn packet_number_space_assumptions_test() {
        assert!(PacketNumberSpace::Initial.as_tag() != 0);
        assert!(PacketNumberSpace::Handshake.as_tag() != 0);
        assert!(PacketNumberSpace::ApplicationData.as_tag() != 0);
    }

    /// Packing then unpacking must recover both the space and the value, including the
    /// largest representable `VarInt`.
    #[test]
    fn round_trip_test() {
        let spaces = [
            PacketNumberSpace::Initial,
            PacketNumberSpace::Handshake,
            PacketNumberSpace::ApplicationData,
        ];
        let values = [
            VarInt::from_u8(0),
            VarInt::from_u8(1),
            VarInt::from_u8(2),
            VarInt::from_u8(u8::MAX / 2),
            VarInt::from_u8(u8::MAX - 1),
            VarInt::from_u8(u8::MAX),
            VarInt::from_u16(u16::MAX / 2),
            VarInt::from_u16(u16::MAX - 1),
            VarInt::from_u16(u16::MAX),
            VarInt::from_u32(u32::MAX / 2),
            VarInt::from_u32(u32::MAX - 1),
            VarInt::from_u32(u32::MAX),
            VarInt::MAX,
        ];

        let cases = spaces
            .iter()
            .flat_map(|&space| values.iter().map(move |&value| (space, value)));

        for (expected_space, expected_value) in cases {
            let packed = PacketNumber::from_varint(expected_value, expected_space);
            assert_eq!(packed.space(), expected_space, "{:#064b}", packed.0);
            assert_eq!(
                PacketNumber::as_varint(packed),
                expected_value,
                "{:#064b}",
                packed.0
            );
        }
    }

    /// Measuring distance across spaces is a logic error and must panic.
    #[test]
    #[should_panic]
    fn wrong_packet_number_space() {
        let application = PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(0));
        let handshake = PacketNumberSpace::Handshake.new_packet_number(VarInt::from_u8(0));
        application.checked_distance(handshake);
    }
}