use super::alloc::vec::Vec;
use core::fmt;
#[cfg(feature = "std")]
use zeroize::{Zeroize, ZeroizeOnDrop};
#[cfg(feature = "arbitrary")]
use arbitrary::{Arbitrary, Unstructured};
#[cfg(feature = "serde")]
use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
use crate::{DeserializeBytes, Error, SerializeBytes, Size};
#[cfg(not(feature = "mls"))]
const MAX_LEN: u64 = (1 << 62) - 1;
#[cfg(not(feature = "mls"))]
const MAX_LEN_LEN_LOG: usize = 3;
#[cfg(feature = "mls")]
const MAX_LEN: u64 = (1 << 30) - 1;
#[cfg(feature = "mls")]
const MAX_LEN_LEN_LOG: usize = 2;
#[inline(always)]
fn check_min_length(length: usize, len_len: usize) -> Result<(), Error> {
if cfg!(feature = "mls") {
let min_len_len = length_encoding_bytes(length as u64)?;
if min_len_len != len_len {
return Err(Error::InvalidVectorLength);
}
};
Ok(())
}
#[inline(always)]
fn calculate_length(len_len_byte: u8) -> Result<(usize, usize), Error> {
let length: usize = (len_len_byte & 0x3F).into();
let len_len_log = (len_len_byte >> 6).into();
if !cfg!(fuzzing) {
debug_assert!(len_len_log <= MAX_LEN_LEN_LOG);
}
if len_len_log > MAX_LEN_LEN_LOG {
return Err(Error::InvalidVectorLength);
}
let len_len = match len_len_log {
0 => 1,
1 => 2,
2 => 4,
3 => 8,
_ => unreachable!(),
};
Ok((length, len_len))
}
#[inline(always)]
fn read_variable_length_bytes(bytes: &[u8]) -> Result<((usize, usize), &[u8]), Error> {
let (len_len_byte, mut remainder) = <u8 as DeserializeBytes>::tls_deserialize(bytes)?;
let (mut length, len_len) = calculate_length(len_len_byte)?;
for _ in 1..len_len {
let (next, next_remainder) = <u8 as DeserializeBytes>::tls_deserialize(remainder)?;
remainder = next_remainder;
length = (length << 8) + usize::from(next);
}
check_min_length(length, len_len)?;
Ok(((length, len_len), remainder))
}
#[inline(always)]
fn length_encoding_bytes(length: u64) -> Result<usize, Error> {
if !cfg!(fuzzing) {
debug_assert!(length <= MAX_LEN);
}
if length > MAX_LEN {
return Err(Error::InvalidVectorLength);
}
Ok(if length <= 0x3f {
1
} else if length <= 0x3fff {
2
} else if length <= 0x3fff_ffff {
4
} else {
8
})
}
#[inline(always)]
fn write_length(content_length: usize) -> Result<Vec<u8>, Error> {
let len_len = length_encoding_bytes(content_length.try_into()?)?;
if !cfg!(fuzzing) {
debug_assert!(len_len <= 8, "Invalid vector len_len {len_len}");
}
if len_len > 8 {
return Err(Error::LibraryError);
}
let mut length_bytes = vec![0u8; len_len];
match len_len {
1 => length_bytes[0] = 0x00,
2 => length_bytes[0] = 0x40,
4 => length_bytes[0] = 0x80,
8 => length_bytes[0] = 0xc0,
_ => {
if !cfg!(fuzzing) {
debug_assert!(false, "Invalid vector len_len {len_len}");
}
return Err(Error::InvalidVectorLength);
}
}
let mut len = content_length;
for b in length_bytes.iter_mut().rev() {
*b |= (len & 0xFF) as u8;
len >>= 8;
}
Ok(length_bytes)
}
impl<T: Size> Size for Vec<T> {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
self.as_slice().tls_serialized_len()
}
}
impl<T: Size> Size for &Vec<T> {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
(*self).tls_serialized_len()
}
}
impl<T: DeserializeBytes> DeserializeBytes for Vec<T> {
#[inline(always)]
fn tls_deserialize(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
let ((length, len_len), mut remainder) = read_variable_length_bytes(bytes)?;
if length == 0 {
return Ok((Vec::new(), remainder));
}
let mut result = Vec::new();
let mut read = len_len;
while (read - len_len) < length {
let (element, next_remainder) = T::tls_deserialize(remainder)?;
remainder = next_remainder;
read += element.tls_serialized_len();
result.push(element);
}
Ok((result, remainder))
}
}
impl<T: SerializeBytes> SerializeBytes for &[T] {
#[inline(always)]
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len());
let mut length = write_length(content_length)?;
let len_len = length.len();
let mut out = Vec::with_capacity(content_length + len_len);
out.append(&mut length);
for e in self.iter() {
out.append(&mut e.tls_serialize()?);
}
#[cfg(debug_assertions)]
if out.len() - len_len != content_length {
return Err(Error::LibraryError);
}
Ok(out)
}
}
impl<T: SerializeBytes> SerializeBytes for &Vec<T> {
#[inline(always)]
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
self.as_slice().tls_serialize()
}
}
impl<T: SerializeBytes> SerializeBytes for Vec<T> {
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
self.as_slice().tls_serialize()
}
}
impl<T: Size> Size for &[T] {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len());
let len_len = length_encoding_bytes(content_length as u64).unwrap_or({
0
});
content_length + len_len
}
}
fn write_hex(f: &mut fmt::Formatter<'_>, data: &[u8]) -> fmt::Result {
if !data.is_empty() {
write!(f, "0x")?;
for byte in data {
write!(f, "{byte:02x}")?;
}
} else {
write!(f, "b\"\"")?;
}
Ok(())
}
macro_rules! impl_vl_bytes_generic {
($name:ident) => {
impl fmt::Debug for $name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} {{ ", stringify!($name))?;
write_hex(f, &self.vec())?;
write!(f, " }}")
}
}
impl $name {
pub fn as_slice(&self) -> &[u8] {
self.vec().as_ref()
}
#[inline]
pub fn push(&mut self, value: u8) {
self.vec_mut().push(value);
}
#[inline]
pub fn pop(&mut self) -> Option<u8> {
self.vec_mut().pop()
}
}
impl From<Vec<u8>> for $name {
fn from(vec: Vec<u8>) -> Self {
Self::new(vec)
}
}
impl From<&[u8]> for $name {
fn from(slice: &[u8]) -> Self {
Self::new(slice.to_vec())
}
}
impl<const N: usize> From<&[u8; N]> for $name {
fn from(slice: &[u8; N]) -> Self {
Self::new(slice.to_vec())
}
}
impl AsRef<[u8]> for $name {
fn as_ref(&self) -> &[u8] {
&self.vec()
}
}
};
}
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
#[cfg_attr(feature = "std", derive(Zeroize))]
#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct VLBytes {
vec: Vec<u8>,
}
impl VLBytes {
pub fn new(vec: Vec<u8>) -> Self {
Self { vec }
}
fn vec(&self) -> &[u8] {
&self.vec
}
fn vec_mut(&mut self) -> &mut Vec<u8> {
&mut self.vec
}
}
impl_vl_bytes_generic!(VLBytes);
impl From<VLBytes> for Vec<u8> {
fn from(b: VLBytes) -> Self {
b.vec
}
}
#[inline(always)]
fn tls_serialize_bytes_len(bytes: &[u8]) -> usize {
let content_length = bytes.len();
let len_len = length_encoding_bytes(content_length as u64).unwrap_or({
0
});
content_length + len_len
}
impl Size for VLBytes {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
tls_serialize_bytes_len(self.as_slice())
}
}
impl DeserializeBytes for VLBytes {
#[inline(always)]
fn tls_deserialize(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
let ((length, _), remainder) = read_variable_length_bytes(bytes)?;
if length == 0 {
return Ok((Self::new(vec![]), remainder));
}
if !cfg!(fuzzing) {
debug_assert!(
length <= MAX_LEN as usize,
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
);
}
if length > MAX_LEN as usize {
return Err(Error::DecodingError(format!(
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
)));
}
match remainder.get(..length).ok_or(Error::EndOfStream) {
Ok(vec) => Ok((Self { vec: vec.to_vec() }, &remainder[length..])),
Err(_e) => {
let remaining_len = remainder.len();
if !cfg!(fuzzing) {
debug_assert_eq!(
remaining_len, length,
"Expected to read {length} bytes but {remaining_len} were read.",
);
}
Err(Error::DecodingError(format!(
"{remaining_len} bytes were read but {length} were expected",
)))
}
}
}
}
impl Size for &VLBytes {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
(*self).tls_serialized_len()
}
}
pub struct VLByteSlice<'a>(pub &'a [u8]);
impl<'a> fmt::Debug for VLByteSlice<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "VLByteSlice {{ ")?;
write_hex(f, self.0)?;
write!(f, " }}")
}
}
impl<'a> VLByteSlice<'a> {
#[inline(always)]
pub fn as_slice(&self) -> &[u8] {
self.0
}
}
impl<'a> Size for &VLByteSlice<'a> {
#[inline]
fn tls_serialized_len(&self) -> usize {
tls_serialize_bytes_len(self.0)
}
}
impl<'a> Size for VLByteSlice<'a> {
#[inline]
fn tls_serialized_len(&self) -> usize {
tls_serialize_bytes_len(self.0)
}
}
#[cfg(feature = "std")]
mod rw {
use super::*;
use crate::{Deserialize, Serialize};
#[inline]
pub(super) fn read_variable_length<R: std::io::Read>(
bytes: &mut R,
) -> Result<(usize, usize), Error> {
let mut len_len_byte = [0u8; 1];
if bytes.read(&mut len_len_byte)? == 0 {
return Err(Error::InvalidVectorLength);
}
let len_len_byte = len_len_byte[0];
let (mut length, len_len) = calculate_length(len_len_byte)?;
for _ in 1..len_len {
let mut next = [0u8; 1];
bytes.read_exact(&mut next)?;
length = (length << 8) + usize::from(next[0]);
}
check_min_length(length, len_len)?;
Ok((length, len_len))
}
impl<T: Deserialize> Deserialize for Vec<T> {
#[inline(always)]
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
let (length, len_len) = read_variable_length(bytes)?;
if length == 0 {
return Ok(Vec::new());
}
let mut result = Vec::new();
let mut read = len_len;
while (read - len_len) < length {
let element = T::tls_deserialize(bytes)?;
read += element.tls_serialized_len();
result.push(element);
}
Ok(result)
}
}
#[inline(always)]
pub(super) fn write_length<W: std::io::Write>(
writer: &mut W,
content_length: usize,
) -> Result<usize, Error> {
let buf = super::write_length(content_length)?;
let buf_len = buf.len();
writer.write_all(&buf)?;
Ok(buf_len)
}
impl<T: Serialize + std::fmt::Debug> Serialize for Vec<T> {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
self.as_slice().tls_serialize(writer)
}
}
impl<T: Serialize + std::fmt::Debug> Serialize for &[T] {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len());
let len_len = write_length(writer, content_length)?;
#[cfg(debug_assertions)]
let mut written = 0;
for e in self.iter() {
#[cfg(debug_assertions)]
{
written += e.tls_serialize(writer)?;
}
#[cfg(not(debug_assertions))]
e.tls_serialize(writer)?;
}
#[cfg(debug_assertions)]
if written != content_length {
return Err(Error::LibraryError);
}
Ok(content_length + len_len)
}
}
}
#[cfg(feature = "std")]
use rw::*;
#[cfg(feature = "std")]
mod rw_bytes {
use super::*;
use crate::{Deserialize, Serialize};
#[inline(always)]
fn tls_serialize_bytes<W: std::io::Write>(
writer: &mut W,
bytes: &[u8],
) -> Result<usize, Error> {
let content_length = bytes.len();
if !cfg!(fuzzing) {
debug_assert!(
content_length as u64 <= MAX_LEN,
"Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}",
);
}
if content_length as u64 > MAX_LEN {
return Err(Error::InvalidVectorLength);
}
let length_bytes = write_length(content_length)?;
let len_len = length_bytes.len();
writer.write_all(&length_bytes)?;
let mut written = 0;
written += writer.write(bytes)?;
if !cfg!(fuzzing) {
debug_assert_eq!(
written, content_length,
"{content_length} bytes should have been serialized but {written} were written",
);
}
if written != content_length {
return Err(Error::EncodingError(format!(
"{content_length} bytes should have been serialized but {written} were written",
)));
}
Ok(written + len_len)
}
impl Serialize for VLBytes {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
tls_serialize_bytes(writer, self.as_slice())
}
}
impl Serialize for &VLBytes {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
(*self).tls_serialize(writer)
}
}
impl Deserialize for VLBytes {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
let (length, _) = read_variable_length(bytes)?;
if length == 0 {
return Ok(Self::new(vec![]));
}
if !cfg!(fuzzing) {
debug_assert!(
length <= MAX_LEN as usize,
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
);
}
if length > MAX_LEN as usize {
return Err(Error::DecodingError(format!(
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
)));
}
let mut result = Self {
vec: vec![0u8; length],
};
let read = bytes.read(result.vec.as_mut_slice())?;
if read == length {
return Ok(result);
}
if !cfg!(fuzzing) {
debug_assert_eq!(
read, length,
"Expected to read {length} bytes but {read} were read.",
);
}
Err(Error::DecodingError(format!(
"{read} bytes were read but {length} were expected",
)))
}
}
impl<'a> Serialize for &VLByteSlice<'a> {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
tls_serialize_bytes(writer, self.0)
}
}
impl<'a> Serialize for VLByteSlice<'a> {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
tls_serialize_bytes(writer, self.0)
}
}
}
#[cfg(feature = "std")]
mod secret_bytes {
use super::*;
use crate::{Deserialize, Serialize};
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd, Zeroize, ZeroizeOnDrop)]
pub struct SecretVLBytes(VLBytes);
impl SecretVLBytes {
pub fn new(vec: Vec<u8>) -> Self {
Self(VLBytes { vec })
}
fn vec(&self) -> &[u8] {
&self.0.vec
}
fn vec_mut(&mut self) -> &mut Vec<u8> {
&mut self.0.vec
}
}
impl_vl_bytes_generic!(SecretVLBytes);
impl Size for SecretVLBytes {
fn tls_serialized_len(&self) -> usize {
self.0.tls_serialized_len()
}
}
impl DeserializeBytes for SecretVLBytes {
fn tls_deserialize(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
where
Self: Sized,
{
let (bytes, remainder) = <VLBytes as DeserializeBytes>::tls_deserialize(bytes)?;
Ok((Self(bytes), remainder))
}
}
impl Serialize for SecretVLBytes {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
self.0.tls_serialize(writer)
}
}
impl Deserialize for SecretVLBytes {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error>
where
Self: Sized,
{
Ok(Self(<VLBytes as Deserialize>::tls_deserialize(bytes)?))
}
}
}
#[cfg(feature = "std")]
pub use secret_bytes::SecretVLBytes;
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for VLBytes {
fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
let mut vec = Vec::arbitrary(u)?;
vec.truncate(MAX_LEN as usize);
Ok(Self { vec })
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod test {
use crate::{SecretVLBytes, VLByteSlice, VLBytes};
use std::println;
#[test]
fn test_debug() {
let tests = [
(vec![], "b\"\""),
(vec![0x00], "0x00"),
(vec![0xAA], "0xaa"),
(vec![0xFF], "0xff"),
(vec![0x00, 0x00], "0x0000"),
(vec![0x00, 0xAA], "0x00aa"),
(vec![0x00, 0xFF], "0x00ff"),
(vec![0xff, 0xff], "0xffff"),
];
for (test, expected) in tests.into_iter() {
println!("\n# {test:?}");
let expected_vl_byte_slice = format!("VLByteSlice {{ {expected} }}");
let got = format!("{:?}", VLByteSlice(&test));
println!("{got}");
assert_eq!(expected_vl_byte_slice, got);
let expected_vl_bytes = format!("VLBytes {{ {expected} }}");
let got = format!("{:?}", VLBytes::new(test.clone()));
println!("{got}");
assert_eq!(expected_vl_bytes, got);
let expected_secret_vl_bytes = format!("SecretVLBytes {{ {expected} }}");
let got = format!("{:?}", SecretVLBytes::new(test.clone()));
println!("{got}");
assert_eq!(expected_secret_vl_bytes, got);
}
}
}