use crate::{
crypto::{EncryptedPayload, ProtectedPayload},
packet::number::{PacketNumberSpace, TruncatedPacketNumber},
};
use s2n_codec::{DecoderBuffer, DecoderError};
/// A key that derives QUIC header protection masks from ciphertext samples.
///
/// The trait exposes two directions, named "opening" and "sealing". Each direction
/// reports the sample length it expects and turns a sample into a
/// [`HeaderProtectionMask`]. By naming convention, opening is presumably the
/// unprotect side and sealing the protect side (confirm against implementors).
pub trait HeaderKey: Send {
    /// Number of ciphertext bytes this key expects as its opening sample.
    fn opening_sample_len(&self) -> usize;

    /// Derives the opening-direction header protection mask from `ciphertext_sample`.
    fn opening_header_protection_mask(&self, ciphertext_sample: &[u8]) -> HeaderProtectionMask;

    /// Number of ciphertext bytes this key expects as its sealing sample.
    fn sealing_sample_len(&self) -> usize;

    /// Derives the sealing-direction header protection mask from `ciphertext_sample`.
    fn sealing_header_protection_mask(&self, ciphertext_sample: &[u8]) -> HeaderProtectionMask;
}
/// Length of a header protection mask.
///
/// Byte 0 masks bits of the first header byte. The remaining bytes are applied to
/// the packet number field (up to four bytes).
pub const HEADER_PROTECTION_MASK_LEN: usize = 5;

/// Mask returned by a [`HeaderKey`] for a given ciphertext sample.
pub type HeaderProtectionMask = [u8; HEADER_PROTECTION_MASK_LEN];

/// Header Form bit of the first header byte; set on long-header packets.
const LONG_HEADER_TAG: u8 = 0b1000_0000;

/// First-byte bits covered by header protection on long-header packets (low 4 bits).
pub(crate) const LONG_HEADER_MASK: u8 = 0b0000_1111;

/// First-byte bits covered by header protection on short-header packets (low 5 bits).
pub(crate) const SHORT_HEADER_MASK: u8 = 0b0001_1111;
/// Returns the bits of the first header byte that header protection covers.
///
/// If every bit of `LONG_HEADER_TAG` is set in `tag`, the packet is treated as a long
/// header and `LONG_HEADER_MASK` is returned. Otherwise `SHORT_HEADER_MASK` is returned.
#[inline(always)]
fn mask_from_packet_tag(tag: u8) -> u8 {
    let is_long_header = (tag & LONG_HEADER_TAG) == LONG_HEADER_TAG;

    if is_long_header {
        LONG_HEADER_MASK
    } else {
        SHORT_HEADER_MASK
    }
}
/// XORs `payload` in place with `mask[1..]`, position by position.
///
/// Byte 0 of `mask` is skipped because it belongs to the first header byte. Only
/// `min(payload.len(), mask.len() - 1)` bytes are touched. Panics if `mask` is empty.
#[inline(always)]
fn xor_mask(payload: &mut [u8], mask: &[u8]) {
    let packet_number_mask = &mask[1..];
    payload
        .iter_mut()
        .zip(packet_number_mask.iter())
        .for_each(|(byte, key)| *byte ^= key);
}
/// Applies header protection to `payload` in place and returns it as a
/// [`ProtectedPayload`].
///
/// The first byte is XORed with `mask[0]`, restricted to the bits chosen by
/// `mask_from_packet_tag`. The packet number bytes, at
/// `header_len..header_len + packet_number_len.bytesize()`, are XORed with `mask[1..]`.
///
/// # Panics
/// Panics if the buffer is empty or shorter than the end of the packet number field.
#[inline]
pub(crate) fn apply_header_protection(
    mask: HeaderProtectionMask,
    payload: EncryptedPayload,
) -> ProtectedPayload {
    let header_len = payload.header_len;
    // On the sealing side the packet number length is already known. Its end offset can
    // therefore be computed before the first byte is masked.
    let packet_number_end = header_len + payload.packet_number_len.bytesize();
    let bytes = payload.buffer.into_less_safe_slice();

    let first_byte = bytes[0];
    bytes[0] = first_byte ^ (mask[0] & mask_from_packet_tag(first_byte));

    xor_mask(&mut bytes[header_len..packet_number_end], &mask);

    ProtectedPayload::new(header_len, bytes)
}
/// Removes header protection from `payload` in place.
///
/// Returns the decoded truncated packet number together with the payload rewrapped as
/// an [`EncryptedPayload`].
///
/// The first byte is unmasked first, because `space.new_packet_number_len` derives
/// the packet number length from the unmasked value. The packet number bytes that
/// follow the header are then unmasked with `mask[1..]` and decoded.
///
/// # Errors
/// Returns the [`DecoderError`] from `decode_truncated_packet_number` if decoding fails.
///
/// # Panics
/// Panics if the buffer is empty or shorter than the end of the packet number field.
#[inline]
pub(crate) fn remove_header_protection(
    space: PacketNumberSpace,
    mask: HeaderProtectionMask,
    payload: ProtectedPayload,
) -> Result<(TruncatedPacketNumber, EncryptedPayload), DecoderError> {
    let header_len = payload.header_len;
    let bytes = payload.buffer.into_less_safe_slice();

    // Order matters: the packet number length can only be read once byte 0 is unmasked.
    let protected_first_byte = bytes[0];
    let first_byte = protected_first_byte ^ (mask[0] & mask_from_packet_tag(protected_first_byte));
    bytes[0] = first_byte;

    let packet_number_len = space.new_packet_number_len(first_byte);
    let packet_number_end = header_len + packet_number_len.bytesize();

    let packet_number_bytes = &mut bytes[header_len..packet_number_end];
    xor_mask(packet_number_bytes, &mask);
    let (packet_number, _) = packet_number_len
        .decode_truncated_packet_number(DecoderBuffer::new(packet_number_bytes))?;

    let encrypted = EncryptedPayload::new(header_len, packet_number_len, bytes);
    Ok((packet_number, encrypted))
}