use crate::{
connection,
connection::{id::ConnectionInfo, ProcessingError},
crypto::{packet_protection, EncryptedPayload, OneRttHeaderKey, OneRttKey, ProtectedPayload},
packet::{
decoding::HeaderDecoder,
encoding::{PacketEncoder, PacketPayloadEncoder},
number::{
PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
TruncatedPacketNumber,
},
KeyPhase, ProtectedKeyPhase, Tag,
},
transport,
};
use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
/// Pattern matching the 4-bit values whose top two bits are `01`, i.e. header form
/// bit clear and fixed bit set.
///
/// NOTE(review): the range only spans 4 bits, so it presumably matches the high
/// nibble of the first packet byte (e.g. `first_byte >> 4`) — confirm at call sites.
macro_rules! short_tag {
    () => {
        0x4u8..=0x7u8
    };
}
/// Base first byte for an encoded short header: header form bit (0x80) clear,
/// fixed bit (0x40) set. The remaining fields are OR-ed in by `encode_header`.
const ENCODING_TAG: u8 = 0x40;
/// Position of the latency spin bit within the first byte.
const SPIN_BIT_MASK: u8 = 0b0010_0000;
/// The two reserved bits of the first byte. `decrypt` rejects packets where
/// either is set.
const RESERVED_BITS_MASK: u8 = 0b0001_1000;
/// Value of the latency spin bit carried in a short-header packet.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum
/// is total. This lets callers use it in `Eq`-bounded contexts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SpinBit {
    /// The spin bit is clear.
    Zero,
    /// The spin bit is set.
    One,
}
impl Default for SpinBit {
fn default() -> Self {
Self::Zero
}
}
impl SpinBit {
    /// Reads the spin bit out of a packet tag.
    ///
    /// Yields `One` when the `SPIN_BIT_MASK` bit is set in `tag`, otherwise `Zero`.
    fn from_tag(tag: Tag) -> Self {
        match tag & SPIN_BIT_MASK {
            SPIN_BIT_MASK => Self::One,
            _ => Self::Zero,
        }
    }

    /// Returns the bits to OR into the first header byte when encoding:
    /// `SPIN_BIT_MASK` for `One`, `0` for `Zero`.
    fn into_packet_tag_mask(self) -> u8 {
        match self {
            Self::Zero => 0,
            Self::One => SPIN_BIT_MASK,
        }
    }
}
/// A short-header packet.
///
/// The type parameters select how each field is represented at a given processing
/// stage. See the `ProtectedShort`, `EncryptedShort` and `CleartextShort` aliases.
///
/// NOTE: the generic parameters `KeyPhase` and `PacketNumber` deliberately share
/// names with the concrete types imported at the top of this module. Inside this
/// definition they refer to the generic parameters, not those types.
#[derive(Debug)]
pub struct Short<DCID, KeyPhase, PacketNumber, Payload> {
    /// Latency spin bit read from (or written to) the first header byte.
    pub spin_bit: SpinBit,
    /// Key phase; protected until header protection has been removed.
    pub key_phase: KeyPhase,
    /// Destination connection ID, as a range into the packet or as raw bytes.
    pub destination_connection_id: DCID,
    /// Packet number; protected until header protection has been removed.
    pub packet_number: PacketNumber,
    /// Packet payload in its current protection/encryption state.
    pub payload: Payload,
}
/// Short packet as decoded from the wire, with header protection still applied:
/// only the spin bit and destination connection ID range are usable.
pub type ProtectedShort<'a> = Short<
    CheckedRange,
    ProtectedKeyPhase,
    ProtectedPacketNumber,
    ProtectedPayload<'a>,
>;

/// Short packet after header protection is removed: the key phase and packet
/// number are known, but the payload is still encrypted.
pub type EncryptedShort<'a> = Short<
    CheckedRange,
    KeyPhase,
    PacketNumber,
    EncryptedPayload<'a>,
>;

/// Fully decrypted short packet. The destination connection ID is borrowed
/// directly as bytes.
pub type CleartextShort<'a> = Short<
    &'a [u8],
    KeyPhase,
    PacketNumber,
    DecoderBufferMut<'a>,
>;
impl<'a> ProtectedShort<'a> {
    /// Decodes a short-header packet whose header protection has not yet been removed.
    ///
    /// `tag` is the packet's tag byte; the spin bit is read from it.
    /// NOTE(review): presumably the unmodified first byte of the packet — confirm
    /// against the caller.
    ///
    /// The destination connection ID is decoded with `destination_connection_id_decoder`
    /// (using `connection_info`). The packet is then split off from `buffer`.
    ///
    /// Returns the protected packet together with the bytes of `buffer` that follow it.
    /// Any decoding failure from the header decoder is propagated via `?`.
    #[inline]
    pub(crate) fn decode<Validator: connection::id::Validator>(
        tag: Tag,
        buffer: DecoderBufferMut<'a>,
        connection_info: &ConnectionInfo,
        destination_connection_id_decoder: &Validator,
    ) -> DecoderBufferMutResult<'a, ProtectedShort<'a>> {
        let mut header_decoder = HeaderDecoder::new_short(&buffer);

        let destination_connection_id = header_decoder.decode_short_destination_connection_id(
            &buffer,
            connection_info,
            destination_connection_id_decoder,
        )?;

        let (payload, packet_number, remaining) =
            header_decoder.finish_short()?.split_off_packet(buffer)?;

        let packet = Self {
            // The spin bit can be read before unprotecting: short-header protection
            // only masks the low five bits of the first byte (RFC 9001 §5.4.1).
            spin_bit: SpinBit::from_tag(tag),
            // Key phase is among the masked bits, so it stays opaque for now.
            key_phase: ProtectedKeyPhase,
            destination_connection_id,
            packet_number,
            payload,
        };

        Ok((packet, remaining))
    }

    /// Removes header protection with `header_key`.
    ///
    /// This reveals the key phase and the truncated packet number. The packet number
    /// is expanded to a full packet number using `largest_acknowledged_packet_number`.
    ///
    /// # Errors
    ///
    /// Propagates any `packet_protection::Error` returned by `crate::crypto::unprotect`.
    pub fn unprotect<H: OneRttHeaderKey>(
        self,
        header_key: &H,
        largest_acknowledged_packet_number: PacketNumber,
    ) -> Result<EncryptedShort<'a>, packet_protection::Error> {
        let spin_bit = self.spin_bit;
        let destination_connection_id = self.destination_connection_id;

        let (truncated_packet_number, payload) = crate::crypto::unprotect(
            header_key,
            PacketNumberSpace::ApplicationData,
            self.payload,
        )?;

        Ok(Short {
            spin_bit,
            // The key phase bit lives in the now-unmasked tag.
            key_phase: KeyPhase::from_tag(payload.get_tag()),
            destination_connection_id,
            packet_number: truncated_packet_number.expand(largest_acknowledged_packet_number),
            payload,
        })
    }

    /// Returns the destination connection ID bytes, resolved from the stored range
    /// within the payload buffer.
    #[inline]
    pub fn destination_connection_id(&self) -> &[u8] {
        let dcid_bytes = self
            .payload
            .get_checked_range(&self.destination_connection_id);
        dcid_bytes.into_less_safe_slice()
    }
}
impl<'a> EncryptedShort<'a> {
    /// Decrypts the packet payload with `crypto`, producing a cleartext packet.
    ///
    /// # Errors
    ///
    /// - Any error from `crate::crypto::decrypt`, converted into `ProcessingError`.
    /// - `PROTOCOL_VIOLATION` ("reserved bits are non-zero") if either reserved bit
    ///   in the first header byte is set.
    pub fn decrypt<C: OneRttKey>(self, crypto: &C) -> Result<CleartextShort<'a>, ProcessingError> {
        let (header, payload) = crate::crypto::decrypt(crypto, self.packet_number, self.payload)?;
        let header = header.into_less_safe_slice();

        // The reserved bits are covered by header protection, and the payload may not
        // be authentic until AEAD succeeds. So they are only checked once decryption
        // has completed.
        let reserved_bits = header[0] & RESERVED_BITS_MASK;
        if reserved_bits != 0 {
            return Err(transport::Error::PROTOCOL_VIOLATION
                .with_reason("reserved bits are non-zero")
                .into());
        }

        Ok(Short {
            spin_bit: self.spin_bit,
            key_phase: self.key_phase,
            destination_connection_id: self.destination_connection_id.get(header),
            packet_number: self.packet_number,
            payload,
        })
    }

    /// Returns the key phase revealed by header unprotection.
    #[inline]
    pub fn key_phase(&self) -> KeyPhase {
        self.key_phase
    }

    /// Returns the destination connection ID bytes, resolved from the stored range
    /// within the (still encrypted) payload buffer.
    #[inline]
    pub fn destination_connection_id(&self) -> &[u8] {
        let dcid_bytes = self
            .payload
            .get_checked_range(&self.destination_connection_id);
        dcid_bytes.into_less_safe_slice()
    }
}
impl<'a> CleartextShort<'a> {
    /// Returns the destination connection ID bytes.
    ///
    /// After decryption the ID is already stored as a borrowed slice, so no range
    /// lookup is needed.
    #[inline]
    pub fn destination_connection_id(&self) -> &[u8] {
        let dcid: &[u8] = self.destination_connection_id;
        dcid
    }
}
/// Plain (unprotected, unencrypted) encoding of a short packet whose packet number
/// is already truncated.
///
/// The output is the header byte and destination connection ID, then the truncated
/// packet number, then the payload.
impl<DCID, Payload> EncoderValue for Short<DCID, KeyPhase, TruncatedPacketNumber, Payload>
where
    DCID: EncoderValue,
    Payload: EncoderValue,
{
    #[inline]
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        // The header byte needs the packet number length bits, so compute the length
        // before writing the header.
        let packet_number_len = self.packet_number.len();
        self.encode_header(packet_number_len, encoder);
        self.packet_number.encode(encoder);
        self.payload.encode(encoder);
    }
}
impl<DCID, PN, Payload> Short<DCID, KeyPhase, PN, Payload>
where
    DCID: EncoderValue,
{
    /// Writes the short header: the first byte, then the destination connection ID.
    ///
    /// The first byte is `ENCODING_TAG` combined with the spin bit, key phase and
    /// packet-number-length bits. The packet number itself is not written here.
    #[inline]
    fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
        let first_byte = ENCODING_TAG
            | self.spin_bit.into_packet_tag_mask()
            | self.key_phase.into_packet_tag_mask()
            | packet_number_len.into_packet_tag_mask();

        first_byte.encode(encoder);
        self.destination_connection_id.encode(encoder);
    }
}
/// Protected/encrypted encoding of a 1-RTT short packet via the generic
/// `PacketEncoder` machinery.
impl<DCID, Payload, K, H> PacketEncoder<K, H, Payload> for Short<DCID, KeyPhase, PacketNumber, Payload>
where
    DCID: EncoderValue,
    Payload: PacketPayloadEncoder,
    K: OneRttKey,
    H: OneRttHeaderKey,
{
    /// Short headers have no length field, so no cursor is needed to back-fill one.
    type PayloadLenCursor = ();

    /// Gives the encoder mutable access to the payload being written.
    #[inline]
    fn payload(&mut self) -> &mut Payload {
        &mut self.payload
    }

    /// Returns the full packet number of the packet being encoded.
    #[inline]
    fn packet_number(&self) -> PacketNumber {
        self.packet_number
    }

    /// Delegates to the inherent `Short::encode_header`, which writes the first byte
    /// and the destination connection ID.
    #[inline]
    fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
        Short::encode_header(self, packet_number_len, encoder);
    }
}