use crate::packet::{
initial::ProtectedInitial,
long::{
validate_destination_connection_id_len, validate_source_connection_id_len,
DestinationConnectionIdLen, SourceConnectionIdLen, Version,
},
Tag,
};
use core::mem::size_of;
use s2n_codec::{
decoder_invariant, DecoderBuffer, DecoderBufferMut, DecoderBufferMutResult, Encoder,
EncoderValue,
};
/// Range pattern for Version Negotiation tags whose fixed bit is not set.
///
/// Expands to the inclusive `u8` range `0b1000..=0b1011`: bit 3 is set, bit 2 is clear,
/// and the low two bits may take any value. It can be used in pattern position
/// (e.g. a `match` arm) or as an expression.
/// NOTE(review): presumably matched against the upper nibble of the first byte — confirm
/// against the tag parser.
macro_rules! version_negotiation_no_fixed_bit_tag {
    () => {
        0b1000_u8..=0b1011_u8
    };
}
/// Bits that are always OR-ed into the tag when the first byte of a Version Negotiation
/// packet is encoded (the two most significant bits).
const ENCODING_TAG: u8 = 0xC0;
/// Value written into the version field of every encoded Version Negotiation packet.
pub(crate) const VERSION: u32 = 0;
/// A QUIC Version Negotiation packet.
///
/// Generic over `SupportedVersions`, the representation of the supported-versions list
/// (the raw byte slice `&[u8]` when decoded; any `EncoderValue` when encoding).
// Note: the derived `PartialOrd`/`Ord` compare fields in declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VersionNegotiation<'a, SupportedVersions> {
    /// Tag bits of the first byte. When encoding, these are OR-ed with fixed encoding bits.
    pub tag: Tag,
    /// Destination Connection ID bytes.
    pub destination_connection_id: &'a [u8],
    /// Source Connection ID bytes.
    pub source_connection_id: &'a [u8],
    /// The list of supported versions. When decoded, it is a whole number of 32-bit values.
    pub supported_versions: SupportedVersions,
}
/// A Version Negotiation packet whose supported versions are kept as raw bytes.
pub type ProtectedVersionNegotiation<'a> = VersionNegotiation<'a, &'a [u8]>;
/// Same type as [`ProtectedVersionNegotiation`]; presumably provided so Version Negotiation
/// fits the protected/encrypted/cleartext naming used by other packet types — confirm.
pub type EncryptedVersionNegotiation<'a> = VersionNegotiation<'a, &'a [u8]>;
/// Same type as [`ProtectedVersionNegotiation`]; presumably provided so Version Negotiation
/// fits the protected/encrypted/cleartext naming used by other packet types — confirm.
pub type CleartextVersionNegotiation<'a> = VersionNegotiation<'a, &'a [u8]>;
impl<'a> ProtectedVersionNegotiation<'a> {
    /// Decodes a Version Negotiation packet from `buffer`.
    ///
    /// The caller has already parsed and verified the first byte (`tag`) and the version
    /// field, so both are skipped here. `tag` is stored unchanged in the returned packet, and
    /// `_version` is not used. The destination and source connection IDs are read as
    /// length-prefixed slices and validated. The rest of the buffer is taken as the
    /// supported-versions list.
    ///
    /// # Errors
    ///
    /// Returns a decoder error if either connection ID is truncated or has an invalid length,
    /// if the supported-versions list is empty, or if its length is not a multiple of four
    /// bytes.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is shorter than the tag plus version fields. The caller is expected
    /// to have verified them already.
    #[inline]
    pub fn decode(
        tag: Tag,
        _version: Version,
        buffer: DecoderBufferMut,
    ) -> DecoderBufferMutResult<VersionNegotiation<&[u8]>> {
        let buffer = buffer
            .skip(size_of::<Tag>() + size_of::<Version>())
            .expect("tag and version already verified");

        let (destination_connection_id, buffer) =
            buffer.decode_slice_with_len_prefix::<DestinationConnectionIdLen>()?;
        let destination_connection_id = destination_connection_id.into_less_safe_slice();
        validate_destination_connection_id_len(destination_connection_id.len())?;

        let (source_connection_id, buffer) =
            buffer.decode_slice_with_len_prefix::<SourceConnectionIdLen>()?;
        let source_connection_id = source_connection_id.into_less_safe_slice();
        validate_source_connection_id_len(source_connection_id.len())?;

        let (supported_versions, buffer) = buffer.decode::<DecoderBufferMut>()?;
        let supported_versions: &[u8] = supported_versions.into_less_safe_slice();

        // The list must hold at least one version and consist only of whole 32-bit
        // versions, so the iterator never meets a partial trailing value.
        decoder_invariant!(
            supported_versions.len() >= size_of::<u32>(),
            "missing at least one version"
        );
        decoder_invariant!(
            supported_versions.len() % size_of::<u32>() == 0,
            "invalid payload length"
        );

        let packet = VersionNegotiation {
            tag,
            destination_connection_id,
            source_connection_id,
            supported_versions,
        };

        Ok((packet, buffer))
    }

    /// Returns an iterator over the supported versions advertised by this packet.
    ///
    /// Only a short-lived borrow of `self` is needed. The iterator borrows the packet
    /// bytes (`'a`), so it may outlive this borrow.
    #[inline]
    pub fn iter(&self) -> VersionNegotiationIterator<'a> {
        // `Self` is `Copy`, so this consumes a copy of the packet rather than `self`.
        (*self).into_iter()
    }

    /// Returns the Destination Connection ID bytes, borrowed for the packet lifetime `'a`.
    #[inline]
    pub fn destination_connection_id(&self) -> &'a [u8] {
        self.destination_connection_id
    }

    /// Returns the Source Connection ID bytes, borrowed for the packet lifetime `'a`.
    #[inline]
    pub fn source_connection_id(&self) -> &'a [u8] {
        self.source_connection_id
    }
}
impl<'a, SupportedVersions: EncoderValue> VersionNegotiation<'a, SupportedVersions> {
    /// Builds a Version Negotiation packet in reply to `initial_packet`.
    ///
    /// The connection IDs of the received packet are echoed back with their roles swapped:
    /// its source ID becomes this packet's destination ID, and its destination ID becomes
    /// this packet's source ID. The tag is left at zero. When encoding, fixed encoding bits
    /// are OR-ed into it.
    pub fn from_initial(
        initial_packet: &'a ProtectedInitial,
        supported_versions: SupportedVersions,
    ) -> Self {
        // Swap roles: reply to whoever sent the Initial, from the ID it addressed.
        let destination_connection_id = initial_packet.source_connection_id();
        let source_connection_id = initial_packet.destination_connection_id();

        Self {
            tag: 0,
            destination_connection_id,
            source_connection_id,
            supported_versions,
        }
    }
}
impl<'a> IntoIterator for ProtectedVersionNegotiation<'a> {
    type Item = u32;
    type IntoIter = VersionNegotiationIterator<'a>;

    /// Consumes the packet and yields its supported versions as `u32` values.
    fn into_iter(self) -> Self::IntoIter {
        let versions = DecoderBuffer::new(self.supported_versions);
        VersionNegotiationIterator(versions)
    }
}
/// Iterator over the versions in a Version Negotiation packet's supported-versions bytes.
///
/// Wraps a decoder over the bytes not yet consumed. Each call to `next` decodes one `u32`.
#[derive(Clone, Copy, Debug)]
pub struct VersionNegotiationIterator<'a>(DecoderBuffer<'a>);
impl<'a> Iterator for VersionNegotiationIterator<'a> {
type Item = u32;
fn next(&mut self) -> Option<Self::Item> {
if let Ok((value, buffer)) = self.0.decode() {
self.0 = buffer;
Some(value)
} else {
None
}
}
}
impl<'a, SupportedVersions: EncoderValue> EncoderValue
    for VersionNegotiation<'a, SupportedVersions>
{
    /// Writes the packet to `encoder`.
    ///
    /// Output order: the first byte (tag with the encoding bits forced on), the version
    /// field, the length-prefixed destination connection ID, the length-prefixed source
    /// connection ID, and finally the supported versions.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        // The encoding bits are set regardless of what the stored tag contains.
        let first_byte = self.tag | ENCODING_TAG;
        encoder.encode(&first_byte);
        encoder.encode(&VERSION);

        let dcid = self.destination_connection_id;
        dcid.encode_with_len_prefix::<DestinationConnectionIdLen, _>(encoder);

        let scid = self.source_connection_id;
        scid.encode_with_len_prefix::<SourceConnectionIdLen, _>(encoder);

        encoder.encode(&self.supported_versions);
    }
}