use {
crate::error::TlvError,
bytemuck::{Pod, Zeroable},
solana_program::program_error::ProgramError,
};
/// Reinterprets `bytes` as a shared reference to a single `T` without copying.
///
/// # Errors
/// Returns `ProgramError::InvalidArgument` whenever `bytemuck::try_from_bytes`
/// rejects the buffer (per bytemuck's docs: length not equal to
/// `size_of::<T>()`, or insufficient alignment for `T`).
pub fn pod_from_bytes<T: Pod>(bytes: &[u8]) -> Result<&T, ProgramError> {
    match bytemuck::try_from_bytes(bytes) {
        Ok(value) => Ok(value),
        Err(_) => Err(ProgramError::InvalidArgument),
    }
}
/// Reinterprets `bytes` as a mutable reference to a single `T` without copying.
///
/// # Errors
/// Returns `ProgramError::InvalidArgument` whenever
/// `bytemuck::try_from_bytes_mut` rejects the buffer.
pub fn pod_from_bytes_mut<T: Pod>(bytes: &mut [u8]) -> Result<&mut T, ProgramError> {
    bytemuck::try_from_bytes_mut(bytes).or(Err(ProgramError::InvalidArgument))
}
/// Reinterprets `bytes` as a shared slice of `T` without copying.
///
/// # Errors
/// Returns `ProgramError::InvalidArgument` whenever `bytemuck::try_cast_slice`
/// rejects the buffer (e.g. a byte length that is not a multiple of
/// `size_of::<T>()`, or misalignment).
pub fn pod_slice_from_bytes<T: Pod>(bytes: &[u8]) -> Result<&[T], ProgramError> {
    bytemuck::try_cast_slice(bytes)
        .ok()
        .ok_or(ProgramError::InvalidArgument)
}
/// Reinterprets `bytes` as a mutable slice of `T` without copying.
///
/// # Errors
/// Returns `ProgramError::InvalidArgument` whenever
/// `bytemuck::try_cast_slice_mut` rejects the buffer.
pub fn pod_slice_from_bytes_mut<T: Pod>(bytes: &mut [u8]) -> Result<&mut [T], ProgramError> {
    if let Ok(items) = bytemuck::try_cast_slice_mut(bytes) {
        Ok(items)
    } else {
        Err(ProgramError::InvalidArgument)
    }
}
/// Generates `From` conversions in both directions between a byte-array
/// newtype `$pod` and the native integer `$int`.
///
/// `$pod` must be a tuple struct whose only field is the byte array produced
/// by `$int::to_le_bytes()`; the bytes are always interpreted little-endian.
macro_rules! impl_int_conversion {
    ($pod:ty, $int:ty) => {
        impl From<$int> for $pod {
            fn from(value: $int) -> Self {
                let le_bytes = value.to_le_bytes();
                Self(le_bytes)
            }
        }
        impl From<$pod> for $int {
            fn from(wrapped: $pod) -> Self {
                <$int>::from_le_bytes(wrapped.0)
            }
        }
    };
}
/// A `u32` stored as a raw `[u8; 4]` so the type has alignment 1 and can be
/// cast from any byte offset of a buffer.
///
/// Convert to/from `u32` with `From`/`Into`; the byte order is little-endian
/// (see the `impl_int_conversion!` invocation below).
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
#[repr(transparent)]
pub struct PodU32([u8; 4]);
impl_int_conversion!(PodU32, u32);
/// A boolean stored as one byte so it can be embedded in `Pod` structs.
///
/// Encoding writes `1` for `true` and `0` for `false`; decoding treats any
/// non-zero byte as `true`. Note that the derived `PartialEq` compares raw
/// bytes, so two non-zero encodings compare unequal.
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
#[repr(transparent)]
pub struct PodBool(u8);
impl From<bool> for PodBool {
    fn from(flag: bool) -> Self {
        // `u8::from(bool)` yields exactly 1 or 0.
        Self(u8::from(flag))
    }
}
impl From<&PodBool> for bool {
    fn from(pod: &PodBool) -> Self {
        pod.0 > 0
    }
}
impl From<PodBool> for bool {
    fn from(pod: PodBool) -> Self {
        // Delegate to the by-reference conversion so both share one rule.
        Self::from(&pod)
    }
}
/// Size in bytes of the `PodU32` item-count prefix that precedes every
/// serialized slice.
const LENGTH_SIZE: usize = std::mem::size_of::<PodU32>();

/// Read-only, zero-copy view over a length-prefixed array of `Pod` values.
///
/// Serialized layout: `[length: PodU32][T; capacity]`. The bytes after the
/// prefix must be an exact multiple of `size_of::<T>()`. Only the first
/// `length` items are exposed by [`PodSlice::data`].
pub struct PodSlice<'data, T: Pod> {
    length: &'data PodU32,
    data: &'data [T],
}
impl<'data, T: Pod> PodSlice<'data, T> {
    /// Borrows a `PodSlice` directly from its serialized bytes.
    ///
    /// # Errors
    /// - `TlvError::BufferTooSmall` if `data` is shorter than the length prefix,
    ///   or if the stored length exceeds the number of items the buffer holds.
    /// - `TlvError::BufferTooLarge` if the bytes after the prefix are not an
    ///   exact multiple of `size_of::<T>()`.
    /// - `TlvError::CalculationFailure` if `T` is zero-sized.
    /// - `ProgramError::InvalidArgument` if the bytes cannot be cast to `T`.
    pub fn unpack<'a>(data: &'a [u8]) -> Result<Self, ProgramError>
    where
        'a: 'data,
    {
        if data.len() < LENGTH_SIZE {
            return Err(TlvError::BufferTooSmall.into());
        }
        let (length, data) = data.split_at(LENGTH_SIZE);
        let length = pod_from_bytes::<PodU32>(length)?;
        let max_length = max_len_for_type::<T>(data.len())?;
        let data = pod_slice_from_bytes(data)?;
        // The length prefix comes from untrusted bytes; reject any value past
        // the buffer's capacity so `data()` can never slice out of bounds.
        if u32::from(*length) as usize > max_length {
            return Err(TlvError::BufferTooSmall.into());
        }
        Ok(Self { length, data })
    }

    /// Returns the stored items, i.e. the first `length` elements.
    ///
    /// Cannot panic: `unpack` guarantees `length` is within capacity.
    pub fn data(&self) -> &[T] {
        let length = u32::from(*self.length) as usize;
        &self.data[..length]
    }

    /// Total serialized size in bytes (length prefix plus `num_items` items).
    ///
    /// # Errors
    /// `TlvError::CalculationFailure` if the computation overflows `usize`.
    pub fn size_of(num_items: usize) -> Result<usize, ProgramError> {
        std::mem::size_of::<T>()
            .checked_mul(num_items)
            .and_then(|len| len.checked_add(LENGTH_SIZE))
            .ok_or_else(|| TlvError::CalculationFailure.into())
    }
}
/// Mutable, zero-copy view over a length-prefixed array of `Pod` values, used
/// to append items in place.
///
/// Shares the serialized layout of `PodSlice`: a `PodU32` item count followed
/// by a packed `[T; capacity]` region.
pub struct PodSliceMut<'data, T: Pod> {
    /// Current item count, stored in the buffer's length prefix.
    length: &'data mut PodU32,
    /// All item slots, initialized or not.
    data: &'data mut [T],
    /// Capacity of `data`, in items.
    max_length: usize,
}
impl<'data, T: Pod> PodSliceMut<'data, T> {
    /// Shared implementation of `unpack` and `init`.
    ///
    /// With `init == true` the length prefix is reset to zero. This happens
    /// only after all validation succeeds, so a failed call leaves the
    /// caller's buffer unmodified.
    fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result<Self, ProgramError>
    where
        'a: 'data,
    {
        if data.len() < LENGTH_SIZE {
            return Err(TlvError::BufferTooSmall.into());
        }
        let (length, data) = data.split_at_mut(LENGTH_SIZE);
        let length = pod_from_bytes_mut::<PodU32>(length)?;
        let max_length = max_len_for_type::<T>(data.len())?;
        let data = pod_slice_from_bytes_mut(data)?;
        if init {
            *length = 0.into();
        } else if u32::from(*length) as usize > max_length {
            // A stored count beyond capacity indicates corrupted input; reject
            // it instead of letting later indexing go out of bounds.
            return Err(TlvError::BufferTooSmall.into());
        }
        Ok(Self {
            length,
            data,
            max_length,
        })
    }

    /// Borrows an existing slice, keeping its stored length.
    ///
    /// # Errors
    /// Same as `PodSlice::unpack`.
    pub fn unpack<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
    where
        'a: 'data,
    {
        Self::unpack_internal(data, false)
    }

    /// Borrows the buffer as a new, empty slice (stored length set to zero).
    ///
    /// # Errors
    /// Same as `PodSlice::unpack`, except that the previous length value is
    /// never checked.
    pub fn init<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
    where
        'a: 'data,
    {
        Self::unpack_internal(data, true)
    }

    /// Appends `t` after the last stored item and bumps the length prefix.
    ///
    /// # Errors
    /// - `TlvError::BufferTooSmall` if the slice is already at capacity.
    /// - `TlvError::CalculationFailure` if the new count does not fit in `u32`.
    pub fn push(&mut self, t: T) -> Result<(), ProgramError> {
        let length = u32::from(*self.length);
        let index = length as usize;
        // `>=` rather than `==`: an out-of-range count must never reach the
        // indexing below.
        if index >= self.max_length {
            return Err(TlvError::BufferTooSmall.into());
        }
        // Compute the new count before writing so a failure leaves no
        // half-applied update.
        let new_length = length
            .checked_add(1)
            .ok_or(TlvError::CalculationFailure)?;
        self.data[index] = t;
        *self.length = new_length.into();
        Ok(())
    }
}
/// Returns how many whole `T` items occupy exactly `data_len` bytes.
///
/// # Errors
/// - `TlvError::CalculationFailure` if `T` is zero-sized (nothing to divide by).
/// - `TlvError::BufferTooLarge` if `data_len` leaves trailing bytes that do
///   not form a complete `T`.
fn max_len_for_type<T>(data_len: usize) -> Result<usize, ProgramError> {
    let item_size = std::mem::size_of::<T>();
    match (
        data_len.checked_div(item_size),
        data_len.checked_rem(item_size),
    ) {
        (Some(count), Some(0)) => Ok(count),
        (Some(_), Some(_)) => Err(TlvError::BufferTooLarge.into()),
        _ => Err(TlvError::CalculationFailure.into()),
    }
}