use {
crate::{
discriminator::{Discriminator, TlvDiscriminator},
error::TlvError,
length::Length,
pod::{pod_from_bytes, pod_from_bytes_mut},
},
bytemuck::Pod,
solana_program::program_error::ProgramError,
std::{cmp::Ordering, mem::size_of},
};
/// Computes where the discriminator, length, and value of a TLV entry
/// beginning at `type_start` would sit, without checking any buffer bounds.
///
/// Offsets use saturating arithmetic, so an enormous `type_start` clamps to
/// `usize::MAX` instead of wrapping.
const fn get_indices_unchecked(type_start: usize) -> TlvIndices {
    let length_start = type_start.saturating_add(size_of::<Discriminator>());
    TlvIndices {
        type_start,
        length_start,
        value_start: length_start.saturating_add(size_of::<Length>()),
    }
}
/// Byte offsets of the three fields of a single TLV entry.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the first discriminator byte.
    pub type_start: usize,
    /// Offset of the first byte of the encoded `Length`.
    pub length_start: usize,
    /// Offset of the first value byte (just past the header).
    pub value_start: usize,
}
/// Walks the TLV entries in `tlv_data` looking for `value_discriminator`.
///
/// Returns the indices of the first matching entry. If an entry with the
/// `UNINITIALIZED` discriminator is reached first, its indices are returned
/// when `init` is set (so the caller can allocate there); otherwise the
/// lookup fails with `TlvError::TypeNotFound`.
///
/// # Errors
/// `InvalidAccountData` if an entry header is cut off by the end of the
/// buffer, or if the buffer is exhausted without finding a match or a free
/// slot. Errors from decoding the discriminator or length are propagated.
fn get_indices(
    tlv_data: &[u8],
    value_discriminator: Discriminator,
    init: bool,
) -> Result<TlvIndices, ProgramError> {
    let data_len = tlv_data.len();
    let mut cursor = 0;
    while cursor < data_len {
        let indices = get_indices_unchecked(cursor);
        // The whole header (discriminator + length) must be present.
        if indices.value_start > data_len {
            return Err(ProgramError::InvalidAccountData);
        }
        let found =
            Discriminator::try_from(&tlv_data[indices.type_start..indices.length_start])?;
        if found == value_discriminator {
            return Ok(indices);
        }
        if found == Discriminator::UNINITIALIZED {
            return if init {
                Ok(indices)
            } else {
                Err(TlvError::TypeNotFound.into())
            };
        }
        // Skip over this entry's value to the next header.
        let length =
            pod_from_bytes::<Length>(&tlv_data[indices.length_start..indices.value_start])?;
        cursor = indices.value_start.saturating_add(usize::try_from(*length)?);
    }
    Err(ProgramError::InvalidAccountData)
}
/// Parses every TLV entry from the start of `tlv_data`.
///
/// Returns the discriminators of all initialized entries, in order, together
/// with the index just past the last entry. Parsing stops early at an
/// `UNINITIALIZED` discriminator. It also stops when too few bytes remain to
/// hold a discriminator; in that case those trailing bytes must all be zero.
///
/// # Errors
/// `InvalidAccountData` if the trailing bytes are not all zero, if an entry
/// header is cut off by the end of the buffer, or if an entry's value runs
/// past the end of the buffer. Decoding errors are propagated.
fn get_discriminators_and_end_index(
    tlv_data: &[u8],
) -> Result<(Vec<Discriminator>, usize), ProgramError> {
    let data_len = tlv_data.len();
    let mut found = Vec::new();
    let mut cursor = 0;
    while cursor < data_len {
        let indices = get_indices_unchecked(cursor);
        if indices.length_start > data_len {
            // Not enough room for a discriminator: accept only zero padding.
            let tail = &tlv_data[indices.type_start..];
            return if tail.iter().any(|&byte| byte != 0) {
                Err(ProgramError::InvalidAccountData)
            } else {
                Ok((found, indices.type_start))
            };
        }
        let discriminator =
            Discriminator::try_from(&tlv_data[indices.type_start..indices.length_start])?;
        if discriminator == Discriminator::UNINITIALIZED {
            return Ok((found, indices.type_start));
        }
        if indices.value_start > data_len {
            return Err(ProgramError::InvalidAccountData);
        }
        found.push(discriminator);
        let length =
            pod_from_bytes::<Length>(&tlv_data[indices.length_start..indices.value_start])?;
        let value_end = indices.value_start.saturating_add(usize::try_from(*length)?);
        if value_end > data_len {
            return Err(ProgramError::InvalidAccountData);
        }
        cursor = value_end;
    }
    Ok((found, cursor))
}
/// Returns the value bytes of the entry for `V` in `tlv_data`.
///
/// # Errors
/// Propagates lookup errors from `get_indices` (for example
/// `TlvError::TypeNotFound`), and returns `InvalidAccountData` if the
/// entry's declared length extends past the end of the buffer.
fn get_bytes<V: TlvDiscriminator>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
    let indices = get_indices(tlv_data, V::TLV_DISCRIMINATOR, false)?;
    let length = pod_from_bytes::<Length>(&tlv_data[indices.length_start..indices.value_start])?;
    let value_end = indices.value_start.saturating_add(usize::try_from(*length)?);
    // `get` yields `None` exactly when `value_end` is beyond the buffer.
    tlv_data
        .get(indices.value_start..value_end)
        .ok_or(ProgramError::InvalidAccountData)
}
/// Read-only access to TLV-encoded data.
///
/// Implementors supply only `get_data`. Every other method is derived from
/// the raw bytes it returns.
pub trait TlvState {
    /// Returns the raw TLV-encoded bytes.
    fn get_data(&self) -> &[u8];

    /// Views the value stored for `V` as a `Pod` reference.
    ///
    /// # Errors
    /// Fails if no entry for `V` can be located, or if its bytes cannot be
    /// reinterpreted as `V` (for example, a size mismatch).
    fn get_value<V: TlvDiscriminator + Pod>(&self) -> Result<&V, ProgramError> {
        let value_bytes = get_bytes::<V>(self.get_data())?;
        pod_from_bytes::<V>(value_bytes)
    }

    /// Borsh-deserializes the value stored for `V`.
    ///
    /// Uses the unchecked Solana helper, so trailing bytes after the encoded
    /// value are not treated as an error.
    #[cfg(feature = "borsh")]
    fn borsh_deserialize<V: TlvDiscriminator + borsh::BorshDeserialize>(
        &self,
    ) -> Result<V, ProgramError> {
        let value_bytes = get_bytes::<V>(self.get_data())?;
        solana_program::borsh::try_from_slice_unchecked::<V>(value_bytes).map_err(Into::into)
    }

    /// Returns the raw value bytes stored for `V`.
    fn get_bytes<V: TlvDiscriminator>(&self) -> Result<&[u8], ProgramError> {
        get_bytes::<V>(self.get_data())
    }

    /// Lists the discriminators of all initialized entries, in storage order.
    fn get_discriminators(&self) -> Result<Vec<Discriminator>, ProgramError> {
        get_discriminators_and_end_index(self.get_data())
            .map(|(discriminators, _end_index)| discriminators)
    }

    /// Size in bytes of one entry header (discriminator plus length).
    fn get_base_len() -> usize {
        get_base_len()
    }
}
/// TLV state that owns its backing buffer.
#[derive(Debug, PartialEq)]
pub struct TlvStateOwned {
    data: Vec<u8>,
}

impl TlvStateOwned {
    /// Takes ownership of `data` once it has been checked as well-formed TLV.
    ///
    /// # Errors
    /// Returns the error produced by `check_data` if the buffer is malformed.
    pub fn unpack(data: Vec<u8>) -> Result<Self, ProgramError> {
        check_data(&data).map(|()| Self { data })
    }
}

impl TlvState for TlvStateOwned {
    fn get_data(&self) -> &[u8] {
        self.data.as_slice()
    }
}
/// TLV state over an immutably borrowed buffer.
#[derive(Debug, PartialEq)]
pub struct TlvStateBorrowed<'data> {
    data: &'data [u8],
}

impl<'data> TlvStateBorrowed<'data> {
    /// Wraps `data` once it has been checked as well-formed TLV.
    ///
    /// # Errors
    /// Returns the error produced by `check_data` if the buffer is malformed.
    pub fn unpack(data: &'data [u8]) -> Result<Self, ProgramError> {
        check_data(data).map(|()| Self { data })
    }
}

impl TlvState for TlvStateBorrowed<'_> {
    fn get_data(&self) -> &[u8] {
        self.data
    }
}
/// TLV state over a mutably borrowed buffer. Entries can be allocated,
/// resized, and edited in place.
#[derive(Debug, PartialEq)]
pub struct TlvStateMut<'data> {
    data: &'data mut [u8],
}
impl<'data> TlvStateMut<'data> {
    /// Wraps a mutable buffer once it has been checked as well-formed TLV.
    ///
    /// # Errors
    /// Returns the error produced by `check_data` if the buffer is malformed.
    pub fn unpack(data: &'data mut [u8]) -> Result<Self, ProgramError> {
        check_data(data)?;
        Ok(Self { data })
    }

    /// Returns a mutable `Pod` view of the value stored for `V`.
    ///
    /// # Errors
    /// Propagates errors from `get_bytes_mut`, and from `pod_from_bytes_mut`
    /// when the stored bytes cannot be viewed as `V` (e.g. a size mismatch).
    pub fn get_value_mut<V: TlvDiscriminator + Pod>(&mut self) -> Result<&mut V, ProgramError> {
        let data = self.get_bytes_mut::<V>()?;
        pod_from_bytes_mut::<V>(data)
    }

    /// Returns the mutable value bytes of the entry for `V`.
    ///
    /// # Errors
    /// Propagates lookup errors from `get_indices` (for example
    /// `TlvError::TypeNotFound`). Returns `InvalidAccountData` if the entry's
    /// declared length extends past the end of the buffer.
    pub fn get_bytes_mut<V: TlvDiscriminator>(&mut self) -> Result<&mut [u8], ProgramError> {
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_indices(self.data, V::TLV_DISCRIMINATOR, false)?;
        let length = pod_from_bytes::<Length>(&self.data[length_start..value_start])?;
        let value_end = value_start.saturating_add(usize::try_from(*length)?);
        if self.data.len() < value_end {
            return Err(ProgramError::InvalidAccountData);
        }
        Ok(&mut self.data[value_start..value_end])
    }

    /// Allocates an entry of `size_of::<V>()` bytes for `V` and writes
    /// `V::default()` into it.
    ///
    /// # Errors
    /// Same as `alloc`, plus any error from viewing the bytes as `V`.
    pub fn init_value<V: TlvDiscriminator + Pod + Default>(
        &mut self,
    ) -> Result<&mut V, ProgramError> {
        let length = size_of::<V>();
        let buffer = self.alloc::<V>(length)?;
        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
        *extension_ref = V::default();
        Ok(extension_ref)
    }

    /// Borsh-serializes `value` into the existing entry for `V`.
    ///
    /// The entry must already be allocated with enough room. A value that
    /// does not fit makes the writer fail with a Borsh I/O error.
    #[cfg(feature = "borsh")]
    pub fn borsh_serialize<V: TlvDiscriminator + borsh::BorshSerialize>(
        &mut self,
        value: &V,
    ) -> Result<(), ProgramError> {
        let data = self.get_bytes_mut::<V>()?;
        borsh::to_writer(&mut data[..], value).map_err(Into::into)
    }

    /// Creates a new entry for `V` with `length` value bytes in the first
    /// uninitialized slot, and returns those value bytes.
    ///
    /// All validation happens before anything is written. A failed
    /// allocation therefore leaves the buffer untouched instead of leaving
    /// behind a header that claims more space than the buffer holds.
    ///
    /// # Errors
    /// `TlvError::TypeAlreadyExists` if an entry for `V` is already present.
    /// `InvalidAccountData` if there is no free slot or the value would not
    /// fit. Any error from converting `length` into a `Length` is also
    /// returned.
    pub fn alloc<V: TlvDiscriminator>(&mut self, length: usize) -> Result<&mut [u8], ProgramError> {
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_indices(self.data, V::TLV_DISCRIMINATOR, true)?;
        let discriminator = Discriminator::try_from(&self.data[type_start..length_start])?;
        if discriminator != Discriminator::UNINITIALIZED {
            return Err(TlvError::TypeAlreadyExists.into());
        }

        // Validate before mutating, keeping the original error precedence:
        // the length conversion first, then the bounds check.
        let encoded_length = Length::try_from(length)?;
        let value_end = value_start.saturating_add(length);
        if self.data.len() < value_end {
            return Err(ProgramError::InvalidAccountData);
        }

        self.data[type_start..length_start].copy_from_slice(V::TLV_DISCRIMINATOR.as_ref());
        let length_ref = pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
        *length_ref = encoded_length;
        Ok(&mut self.data[value_start..value_end])
    }

    /// Resizes the value of the existing entry for `V` to `length` bytes and
    /// returns the resized value bytes.
    ///
    /// Every later entry is shifted so the TLV data stays contiguous. Bytes
    /// gained by growing the value are zero-filled. When shrinking, the bytes
    /// freed at the tail of the TLV data are zero-filled.
    ///
    /// # Errors
    /// Propagates lookup/parse errors (e.g. `TlvError::TypeNotFound`).
    /// Returns `InvalidAccountData` if growing would push the TLV data past
    /// the end of the buffer, plus any error converting `length` into a
    /// `Length`. All of these are detected before the buffer is modified.
    pub fn realloc<V: TlvDiscriminator>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_indices(self.data, V::TLV_DISCRIMINATOR, false)?;
        let (_, end_index) = get_discriminators_and_end_index(self.data)?;
        let data_len = self.data.len();

        let length_ref = pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
        let old_length = usize::try_from(*length_ref)?;

        // Growing must not push the end of the TLV data past the buffer.
        if old_length < length {
            let new_end_index = end_index.saturating_add(length.saturating_sub(old_length));
            if new_end_index > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }
        *length_ref = Length::try_from(length)?;

        // Shift every following entry to start right after the resized value.
        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        self.data
            .copy_within(old_value_end..end_index, new_value_end);

        match old_length.cmp(&length) {
            Ordering::Greater => {
                // Shrunk: clear the bytes vacated at the tail.
                let new_end_index = end_index.saturating_sub(old_length.saturating_sub(length));
                self.data[new_end_index..end_index].fill(0);
            }
            Ordering::Less => {
                // Grown: clear the newly exposed value bytes.
                self.data[old_value_end..new_value_end].fill(0);
            }
            Ordering::Equal => {}
        }
        Ok(&mut self.data[value_start..new_value_end])
    }
}
impl TlvState for TlvStateMut<'_> {
    /// Exposes the mutable buffer as a shared slice for read-only access.
    fn get_data(&self) -> &[u8] {
        &self.data[..]
    }
}
/// Size in bytes of one TLV entry header, i.e. the offset of the value
/// within an entry that starts at index 0.
const fn get_base_len() -> usize {
    get_indices_unchecked(0).value_start
}
/// Confirms that `tlv_data` parses as a sequence of TLV entries, discarding
/// the parsed discriminators and end index.
fn check_data(tlv_data: &[u8]) -> Result<(), ProgramError> {
    get_discriminators_and_end_index(tlv_data).map(|_| ())
}
#[cfg(test)]
mod test {
    use super::*;
    use bytemuck::{Pod, Zeroable};

    /// One `TestValue` entry (discriminator `[1; 8]`, length 32, value of 32
    /// ones) followed by two bytes of zero padding.
    const TEST_BUFFER: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 32, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 0, 0,
    ];

    /// The same entry as `TEST_BUFFER`, followed by a longer run of zero bytes.
    const TEST_BIG_BUFFER: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 32, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0,
    ];

    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct TestValue {
        data: [u8; 32],
    }
    impl TlvDiscriminator for TestValue {
        const TLV_DISCRIMINATOR: Discriminator = Discriminator::new([1; Discriminator::LENGTH]);
    }

    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct TestSmallValue {
        data: [u8; 3],
    }
    impl TlvDiscriminator for TestSmallValue {
        const TLV_DISCRIMINATOR: Discriminator = Discriminator::new([2; Discriminator::LENGTH]);
    }

    #[repr(transparent)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct TestEmptyValue;
    impl TlvDiscriminator for TestEmptyValue {
        const TLV_DISCRIMINATOR: Discriminator = Discriminator::new([3; Discriminator::LENGTH]);
    }

    /// A value whose `Default` is not all zeros, so `init_value` is shown to
    /// really write `V::default()`.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
    struct TestNonZeroDefault {
        data: [u8; 5],
    }
    const TEST_NON_ZERO_DEFAULT_DATA: [u8; 5] = [4; 5];
    impl TlvDiscriminator for TestNonZeroDefault {
        const TLV_DISCRIMINATOR: Discriminator = Discriminator::new([4; Discriminator::LENGTH]);
    }
    impl Default for TestNonZeroDefault {
        fn default() -> Self {
            Self {
                data: TEST_NON_ZERO_DEFAULT_DATA,
            }
        }
    }

    #[test]
    fn unpack_opaque_buffer() {
        let state = TlvStateBorrowed::unpack(TEST_BUFFER).unwrap();
        let value = state.get_value::<TestValue>().unwrap();
        assert_eq!(value.data, [1; 32]);
        assert_eq!(
            state.get_value::<TestEmptyValue>(),
            Err(ProgramError::InvalidAccountData)
        );

        let mut test_buffer = TEST_BUFFER.to_vec();
        let state = TlvStateMut::unpack(&mut test_buffer).unwrap();
        let value = state.get_value::<TestValue>().unwrap();
        assert_eq!(value.data, [1; 32]);

        let state = TlvStateOwned::unpack(test_buffer).unwrap();
        let value = state.get_value::<TestValue>().unwrap();
        assert_eq!(value.data, [1; 32]);
    }

    #[test]
    fn fail_unpack_opaque_buffer() {
        // Too short for a discriminator and not zero padding: every state
        // flavour must reject it.
        let mut buffer = vec![0, 3];
        assert_eq!(
            TlvStateBorrowed::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            TlvStateMut::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            TlvStateOwned::unpack(buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Corrupted discriminator: still parses, but `TestValue` cannot be found.
        let mut buffer = TEST_BUFFER.to_vec();
        buffer[0] += 1;
        let state = TlvStateMut::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.get_value::<TestValue>(),
            Err(ProgramError::InvalidAccountData)
        );

        // Declared length runs past the end of the buffer.
        let mut buffer = TEST_BUFFER.to_vec();
        buffer[Discriminator::LENGTH] += 10;
        assert_eq!(
            TlvStateMut::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Declared length smaller than `TestValue`: the bytes can't be viewed as it.
        let mut buffer = TEST_BIG_BUFFER.to_vec();
        buffer[Discriminator::LENGTH] -= 1;
        let state = TlvStateMut::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.get_value::<TestValue>(),
            Err(ProgramError::InvalidArgument)
        );

        // Truncated buffer.
        let buffer = &TEST_BUFFER[..TEST_BUFFER.len() - 5];
        assert_eq!(
            TlvStateBorrowed::unpack(buffer),
            Err(ProgramError::InvalidAccountData)
        );
    }

    #[test]
    fn get_discriminators_with_opaque_buffer() {
        // Non-zero bytes too short to hold a discriminator.
        assert_eq!(
            get_discriminators_and_end_index(&[1, 0, 1, 1]).unwrap_err(),
            ProgramError::InvalidAccountData,
        );
        // A single zero-length entry.
        assert_eq!(
            get_discriminators_and_end_index(&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
            (vec![Discriminator::try_from(1).unwrap()], 12)
        );
        // All zeros: nothing initialized.
        assert_eq!(
            get_discriminators_and_end_index(&[0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
            (vec![], 0)
        );
    }

    #[test]
    fn value_pack_unpack() {
        let account_size =
            get_base_len() + size_of::<TestValue>() + get_base_len() + size_of::<TestSmallValue>();
        let mut buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();

        let value = state.init_value::<TestValue>().unwrap();
        let data = [100; 32];
        value.data = data;
        assert_eq!(
            &state.get_discriminators().unwrap(),
            &[TestValue::TLV_DISCRIMINATOR],
        );
        assert_eq!(&state.get_value::<TestValue>().unwrap().data, &data,);
        assert_eq!(
            state.init_value::<TestValue>().unwrap_err(),
            TlvError::TypeAlreadyExists.into(),
        );

        let mut expect = vec![];
        expect.extend_from_slice(TestValue::TLV_DISCRIMINATOR.as_ref());
        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
        expect.extend_from_slice(&data);
        expect.extend_from_slice(&[0; size_of::<Discriminator>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
        assert_eq!(expect, buffer);

        // Mutate through a `&mut` view; the binding itself need not be `mut`.
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let unpacked = state.get_value_mut::<TestValue>().unwrap();
        assert_eq!(*unpacked, TestValue { data });
        let new_data = [101; 32];
        unpacked.data = new_data;

        let state = TlvStateBorrowed::unpack(&buffer).unwrap();
        let unpacked = state.get_value::<TestValue>().unwrap();
        assert_eq!(*unpacked, TestValue { data: new_data });

        let mut expect = vec![];
        expect.extend_from_slice(TestValue::TLV_DISCRIMINATOR.as_ref());
        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
        expect.extend_from_slice(&new_data);
        expect.extend_from_slice(&[0; size_of::<Discriminator>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
        assert_eq!(expect, buffer);

        // Fill the remaining space with a second entry.
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let new_value = state.init_value::<TestSmallValue>().unwrap();
        let small_data = [102; 3];
        new_value.data = small_data;
        assert_eq!(
            &state.get_discriminators().unwrap(),
            &[
                TestValue::TLV_DISCRIMINATOR,
                TestSmallValue::TLV_DISCRIMINATOR
            ]
        );

        let mut expect = vec![];
        expect.extend_from_slice(TestValue::TLV_DISCRIMINATOR.as_ref());
        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
        expect.extend_from_slice(&new_data);
        expect.extend_from_slice(TestSmallValue::TLV_DISCRIMINATOR.as_ref());
        expect.extend_from_slice(
            &u32::try_from(size_of::<TestSmallValue>())
                .unwrap()
                .to_le_bytes(),
        );
        expect.extend_from_slice(&small_data);
        assert_eq!(expect, buffer);

        // The buffer is full, so even an empty value cannot be added.
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.init_value::<TestEmptyValue>(),
            Err(ProgramError::InvalidAccountData),
        );
    }

    #[test]
    fn value_any_order() {
        let account_size =
            get_base_len() + size_of::<TestValue>() + get_base_len() + size_of::<TestSmallValue>();
        let mut buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();

        let data = [99; 32];
        let small_data = [98; 3];

        // Big value first, then small value.
        let value = state.init_value::<TestValue>().unwrap();
        value.data = data;
        let value = state.init_value::<TestSmallValue>().unwrap();
        value.data = small_data;
        assert_eq!(
            &state.get_discriminators().unwrap(),
            &[
                TestValue::TLV_DISCRIMINATOR,
                TestSmallValue::TLV_DISCRIMINATOR,
            ]
        );

        // Small value first, then big value.
        let mut other_buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut other_buffer).unwrap();
        let value = state.init_value::<TestSmallValue>().unwrap();
        value.data = small_data;
        let value = state.init_value::<TestValue>().unwrap();
        value.data = data;
        assert_eq!(
            &state.get_discriminators().unwrap(),
            &[
                TestSmallValue::TLV_DISCRIMINATOR,
                TestValue::TLV_DISCRIMINATOR,
            ]
        );

        // Different layouts, same logical contents.
        assert_ne!(buffer, other_buffer);
        let state = TlvStateBorrowed::unpack(&buffer).unwrap();
        let other_state = TlvStateBorrowed::unpack(&other_buffer).unwrap();
        assert_eq!(
            state.get_value::<TestValue>().unwrap(),
            other_state.get_value::<TestValue>().unwrap()
        );
        assert_eq!(
            state.get_value::<TestSmallValue>().unwrap(),
            other_state.get_value::<TestSmallValue>().unwrap()
        );
    }

    #[test]
    fn init_nonzero_default() {
        let account_size = get_base_len() + size_of::<TestNonZeroDefault>();
        let mut buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let value = state.init_value::<TestNonZeroDefault>().unwrap();
        assert_eq!(value.data, TEST_NON_ZERO_DEFAULT_DATA);
    }

    #[test]
    fn init_buffer_too_small() {
        let account_size = get_base_len() + size_of::<TestValue>();
        let mut buffer = vec![0; account_size - 1];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let err = state.init_value::<TestValue>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Hand-write a header that claims one byte more than is available.
        let discriminator_ref = &mut state.data[0..Discriminator::LENGTH];
        discriminator_ref.copy_from_slice(TestValue::TLV_DISCRIMINATOR.as_ref());
        state.data[Discriminator::LENGTH] = 32;
        let err = state.get_value::<TestValue>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        assert_eq!(
            state.get_discriminators().unwrap_err(),
            ProgramError::InvalidAccountData
        );
    }

    #[test]
    fn value_with_no_data() {
        let account_size = get_base_len() + size_of::<TestEmptyValue>();
        let mut buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.get_value::<TestEmptyValue>().unwrap_err(),
            TlvError::TypeNotFound.into(),
        );
        state.init_value::<TestEmptyValue>().unwrap();
        state.get_value::<TestEmptyValue>().unwrap();
        assert_eq!(
            state.init_value::<TestEmptyValue>().unwrap_err(),
            TlvError::TypeAlreadyExists.into(),
        );
    }

    #[test]
    fn alloc() {
        let tlv_size = 1;
        let account_size = get_base_len() + tlv_size;
        let mut buffer = vec![0; account_size];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();

        // Allocation succeeds, but the bytes are too few to view as `TestValue`.
        let data = state.alloc::<TestValue>(tlv_size).unwrap();
        assert_eq!(
            pod_from_bytes_mut::<TestValue>(data).unwrap_err(),
            ProgramError::InvalidArgument,
        );

        assert_eq!(
            state.alloc::<TestValue>(tlv_size).unwrap_err(),
            TlvError::TypeAlreadyExists.into(),
        );
    }

    #[test]
    fn realloc() {
        const TLV_SIZE: usize = 10;
        const EXTRA_SPACE: usize = 5;
        const SMALL_SIZE: usize = 2;
        const ACCOUNT_SIZE: usize = get_base_len()
            + TLV_SIZE
            + EXTRA_SPACE
            + get_base_len()
            + size_of::<TestNonZeroDefault>();

        let mut buffer = vec![0; ACCOUNT_SIZE];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let _ = state.alloc::<TestValue>(TLV_SIZE).unwrap();
        let _ = state.init_value::<TestNonZeroDefault>().unwrap();

        // Grow into the spare space; the following entry must survive the shift.
        let data = state.realloc::<TestValue>(TLV_SIZE + EXTRA_SPACE).unwrap();
        assert_eq!(data, [0; TLV_SIZE + EXTRA_SPACE]);
        let value = state.get_value::<TestNonZeroDefault>().unwrap();
        assert_eq!(*value, TestNonZeroDefault::default());

        // Shrink; the following entry must survive and the tail must be zeroed.
        let data = state.realloc::<TestValue>(SMALL_SIZE).unwrap();
        assert_eq!(data, [0; SMALL_SIZE]);
        let value = state.get_value::<TestNonZeroDefault>().unwrap();
        assert_eq!(*value, TestNonZeroDefault::default());
        let (_, end_index) = get_discriminators_and_end_index(&buffer).unwrap();
        assert_eq!(
            &buffer[end_index..ACCOUNT_SIZE],
            [0; TLV_SIZE + EXTRA_SPACE - SMALL_SIZE]
        );

        // Growing beyond the buffer is rejected.
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        assert_eq!(
            state
                .realloc::<TestValue>(TLV_SIZE + EXTRA_SPACE + 1)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
}
#[cfg(all(test, feature = "borsh"))]
mod borsh_test {
    use super::*;

    /// Variable-length Borsh value used to exercise serialization into a
    /// fixed-size TLV allocation.
    #[derive(Clone, Debug, PartialEq, borsh::BorshDeserialize, borsh::BorshSerialize)]
    struct TestBorsh {
        data: String,
        inner: TestInnerBorsh,
    }

    /// Nested Borsh value, so the round trip covers a non-flat layout.
    #[derive(Clone, Debug, PartialEq, borsh::BorshDeserialize, borsh::BorshSerialize)]
    struct TestInnerBorsh {
        data: String,
    }

    impl TlvDiscriminator for TestBorsh {
        const TLV_DISCRIMINATOR: Discriminator = Discriminator::new([5; Discriminator::LENGTH]);
    }

    #[test]
    fn borsh_value() {
        let outer_text = "This is a pretty cool test!";
        let inner_text = "And it gets even cooler!";
        // Each `String` takes a 4-byte length prefix plus its bytes.
        let value_len = 4 + outer_text.len() + 4 + inner_text.len();

        let mut buffer = vec![0; get_base_len() + value_len];
        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
        let _ = state.alloc::<TestBorsh>(value_len).unwrap();

        let expected = TestBorsh {
            data: outer_text.to_string(),
            inner: TestInnerBorsh {
                data: inner_text.to_string(),
            },
        };
        state.borsh_serialize(&expected).unwrap();
        let round_trip = state.borsh_deserialize::<TestBorsh>().unwrap();
        assert_eq!(round_trip, expected);

        // One extra character no longer fits in the allocated value.
        let oversized_text = "This is a pretty cool test!?";
        let oversized = TestBorsh {
            data: oversized_text.to_string(),
            inner: TestInnerBorsh {
                data: inner_text.to_string(),
            },
        };
        assert_eq!(
            state.borsh_serialize(&oversized).unwrap_err(),
            ProgramError::BorshIoError("failed to write whole buffer".to_string()),
        );
    }
}