1use {
4 crate::{error::TlvError, length::Length, variable_len_pack::VariableLenPack},
5 bytemuck::Pod,
6 solana_account_info::AccountInfo,
7 solana_program_error::ProgramError,
8 spl_discriminator::{ArrayDiscriminator, SplDiscriminate},
9 spl_pod::bytemuck::{pod_from_bytes, pod_from_bytes_mut},
10 std::{cmp::Ordering, mem::size_of},
11};
12
13const fn get_indices_unchecked(type_start: usize, value_repetition_number: usize) -> TlvIndices {
15 let length_start = type_start.saturating_add(size_of::<ArrayDiscriminator>());
16 let value_start = length_start.saturating_add(size_of::<Length>());
17 TlvIndices {
18 type_start,
19 length_start,
20 value_start,
21 value_repetition_number,
22 }
23}
24
25#[derive(Debug)]
28struct TlvIndices {
29 pub type_start: usize,
30 pub length_start: usize,
31 pub value_start: usize,
32 pub value_repetition_number: usize,
33}
34
35fn get_indices(
36 tlv_data: &[u8],
37 value_discriminator: ArrayDiscriminator,
38 init: bool,
39 repetition_number: Option<usize>,
40) -> Result<TlvIndices, ProgramError> {
41 let mut current_repetition_number = 0;
42 let mut start_index = 0;
43 while start_index < tlv_data.len() {
44 let tlv_indices = get_indices_unchecked(start_index, current_repetition_number);
45 if tlv_data.len() < tlv_indices.value_start {
46 return Err(ProgramError::InvalidAccountData);
47 }
48 let discriminator = ArrayDiscriminator::try_from(
49 &tlv_data[tlv_indices.type_start..tlv_indices.length_start],
50 )?;
51 if discriminator == value_discriminator {
52 if let Some(desired_repetition_number) = repetition_number {
53 if current_repetition_number == desired_repetition_number {
54 return Ok(tlv_indices);
55 }
56 }
57 current_repetition_number += 1;
58 } else if discriminator == ArrayDiscriminator::UNINITIALIZED {
61 if init {
62 return Ok(tlv_indices);
63 } else {
64 return Err(TlvError::TypeNotFound.into());
65 }
66 }
67 let length =
68 pod_from_bytes::<Length>(&tlv_data[tlv_indices.length_start..tlv_indices.value_start])?;
69 let value_end_index = tlv_indices
70 .value_start
71 .saturating_add(usize::try_from(*length)?);
72 start_index = value_end_index;
73 }
74 Err(ProgramError::InvalidAccountData)
75}
76
77fn get_discriminators_and_end_index(
80 tlv_data: &[u8],
81) -> Result<(Vec<ArrayDiscriminator>, usize), ProgramError> {
82 let mut discriminators = vec![];
83 let mut start_index = 0;
84 while start_index < tlv_data.len() {
85 let tlv_indices = get_indices_unchecked(start_index, 0);
88 if tlv_data.len() < tlv_indices.length_start {
89 let remainder = &tlv_data[tlv_indices.type_start..];
91 if remainder.iter().all(|&x| x == 0) {
92 return Ok((discriminators, tlv_indices.type_start));
93 } else {
94 return Err(ProgramError::InvalidAccountData);
95 }
96 }
97 let discriminator = ArrayDiscriminator::try_from(
98 &tlv_data[tlv_indices.type_start..tlv_indices.length_start],
99 )?;
100 if discriminator == ArrayDiscriminator::UNINITIALIZED {
101 return Ok((discriminators, tlv_indices.type_start));
102 } else {
103 if tlv_data.len() < tlv_indices.value_start {
104 return Err(ProgramError::InvalidAccountData);
106 }
107 discriminators.push(discriminator);
108 let length = pod_from_bytes::<Length>(
109 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
110 )?;
111
112 let value_end_index = tlv_indices
113 .value_start
114 .saturating_add(usize::try_from(*length)?);
115 if value_end_index > tlv_data.len() {
116 return Err(ProgramError::InvalidAccountData);
118 }
119 start_index = value_end_index;
120 }
121 }
122 Ok((discriminators, start_index))
123}
124
125fn get_bytes<V: SplDiscriminate>(
126 tlv_data: &[u8],
127 repetition_number: usize,
128) -> Result<&[u8], ProgramError> {
129 let TlvIndices {
130 type_start: _,
131 length_start,
132 value_start,
133 value_repetition_number: _,
134 } = get_indices(
135 tlv_data,
136 V::SPL_DISCRIMINATOR,
137 false,
138 Some(repetition_number),
139 )?;
140 let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
143 let value_end = value_start.saturating_add(usize::try_from(*length)?);
144 if tlv_data.len() < value_end {
145 return Err(ProgramError::InvalidAccountData);
146 }
147 Ok(&tlv_data[value_start..value_end])
148}
149
150pub trait TlvState {
207 fn get_data(&self) -> &[u8];
209
210 fn get_value_with_repetition<V: SplDiscriminate + Pod>(
213 &self,
214 repetition_number: usize,
215 ) -> Result<&V, ProgramError> {
216 let data = get_bytes::<V>(self.get_data(), repetition_number)?;
217 pod_from_bytes::<V>(data)
218 }
219
220 fn get_first_value<V: SplDiscriminate + Pod>(&self) -> Result<&V, ProgramError> {
223 self.get_value_with_repetition::<V>(0)
224 }
225
226 fn get_variable_len_value_with_repetition<V: SplDiscriminate + VariableLenPack>(
229 &self,
230 repetition_number: usize,
231 ) -> Result<V, ProgramError> {
232 let data = get_bytes::<V>(self.get_data(), repetition_number)?;
233 V::unpack_from_slice(data)
234 }
235
236 fn get_first_variable_len_value<V: SplDiscriminate + VariableLenPack>(
239 &self,
240 ) -> Result<V, ProgramError> {
241 self.get_variable_len_value_with_repetition::<V>(0)
242 }
243
244 fn get_bytes_with_repetition<V: SplDiscriminate>(
246 &self,
247 repetition_number: usize,
248 ) -> Result<&[u8], ProgramError> {
249 get_bytes::<V>(self.get_data(), repetition_number)
250 }
251
252 fn get_first_bytes<V: SplDiscriminate>(&self) -> Result<&[u8], ProgramError> {
254 self.get_bytes_with_repetition::<V>(0)
255 }
256
257 fn get_discriminators(&self) -> Result<Vec<ArrayDiscriminator>, ProgramError> {
259 get_discriminators_and_end_index(self.get_data()).map(|v| v.0)
260 }
261
262 fn get_base_len() -> usize {
264 get_base_len()
265 }
266}
267
268#[derive(Debug, PartialEq)]
270pub struct TlvStateOwned {
271 data: Vec<u8>,
273}
274impl TlvStateOwned {
275 pub fn unpack(data: Vec<u8>) -> Result<Self, ProgramError> {
279 check_data(&data)?;
280 Ok(Self { data })
281 }
282}
283impl TlvState for TlvStateOwned {
284 fn get_data(&self) -> &[u8] {
285 &self.data
286 }
287}
288
289#[derive(Debug, PartialEq)]
292pub struct TlvStateBorrowed<'data> {
293 data: &'data [u8],
295}
296impl<'data> TlvStateBorrowed<'data> {
297 pub fn unpack(data: &'data [u8]) -> Result<Self, ProgramError> {
301 check_data(data)?;
302 Ok(Self { data })
303 }
304}
305impl TlvState for TlvStateBorrowed<'_> {
306 fn get_data(&self) -> &[u8] {
307 self.data
308 }
309}
310
311#[derive(Debug, PartialEq)]
314pub struct TlvStateMut<'data> {
315 data: &'data mut [u8],
317}
318impl<'data> TlvStateMut<'data> {
319 pub fn unpack(data: &'data mut [u8]) -> Result<Self, ProgramError> {
323 check_data(data)?;
324 Ok(Self { data })
325 }
326
327 pub fn get_value_with_repetition_mut<V: SplDiscriminate + Pod>(
330 &mut self,
331 repetition_number: usize,
332 ) -> Result<&mut V, ProgramError> {
333 let data = self.get_bytes_with_repetition_mut::<V>(repetition_number)?;
334 pod_from_bytes_mut::<V>(data)
335 }
336
337 pub fn get_first_value_mut<V: SplDiscriminate + Pod>(
340 &mut self,
341 ) -> Result<&mut V, ProgramError> {
342 self.get_value_with_repetition_mut::<V>(0)
343 }
344
345 pub fn get_bytes_with_repetition_mut<V: SplDiscriminate>(
348 &mut self,
349 repetition_number: usize,
350 ) -> Result<&mut [u8], ProgramError> {
351 let TlvIndices {
352 type_start: _,
353 length_start,
354 value_start,
355 value_repetition_number: _,
356 } = get_indices(
357 self.data,
358 V::SPL_DISCRIMINATOR,
359 false,
360 Some(repetition_number),
361 )?;
362
363 let length = pod_from_bytes::<Length>(&self.data[length_start..value_start])?;
364 let value_end = value_start.saturating_add(usize::try_from(*length)?);
365 if self.data.len() < value_end {
366 return Err(ProgramError::InvalidAccountData);
367 }
368 Ok(&mut self.data[value_start..value_end])
369 }
370
371 pub fn get_first_bytes_mut<V: SplDiscriminate>(&mut self) -> Result<&mut [u8], ProgramError> {
374 self.get_bytes_with_repetition_mut::<V>(0)
375 }
376
377 pub fn init_value<V: SplDiscriminate + Pod + Default>(
383 &mut self,
384 allow_repetition: bool,
385 ) -> Result<(&mut V, usize), ProgramError> {
386 let length = size_of::<V>();
387 let (buffer, repetition_number) = self.alloc::<V>(length, allow_repetition)?;
388 let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
389 *extension_ref = V::default();
390 Ok((extension_ref, repetition_number))
391 }
392
393 pub fn pack_variable_len_value_with_repetition<V: SplDiscriminate + VariableLenPack>(
396 &mut self,
397 value: &V,
398 repetition_number: usize,
399 ) -> Result<(), ProgramError> {
400 let data = self.get_bytes_with_repetition_mut::<V>(repetition_number)?;
401 value.pack_into_slice(data)
404 }
405
406 pub fn pack_first_variable_len_value<V: SplDiscriminate + VariableLenPack>(
409 &mut self,
410 value: &V,
411 ) -> Result<(), ProgramError> {
412 self.pack_variable_len_value_with_repetition::<V>(value, 0)
413 }
414
415 pub fn alloc<V: SplDiscriminate>(
417 &mut self,
418 length: usize,
419 allow_repetition: bool,
420 ) -> Result<(&mut [u8], usize), ProgramError> {
421 let TlvIndices {
422 type_start,
423 length_start,
424 value_start,
425 value_repetition_number,
426 } = get_indices(
427 self.data,
428 V::SPL_DISCRIMINATOR,
429 true,
430 if allow_repetition { None } else { Some(0) },
431 )?;
432
433 let discriminator = ArrayDiscriminator::try_from(&self.data[type_start..length_start])?;
434 if discriminator == ArrayDiscriminator::UNINITIALIZED {
435 let discriminator_ref = &mut self.data[type_start..length_start];
437 discriminator_ref.copy_from_slice(V::SPL_DISCRIMINATOR.as_ref());
438 let length_ref =
440 pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
441 *length_ref = Length::try_from(length)?;
442
443 let value_end = value_start.saturating_add(length);
444 if self.data.len() < value_end {
445 return Err(ProgramError::InvalidAccountData);
446 }
447 Ok((
448 &mut self.data[value_start..value_end],
449 value_repetition_number,
450 ))
451 } else {
452 Err(TlvError::TypeAlreadyExists.into())
453 }
454 }
455
456 pub fn alloc_and_pack_variable_len_entry<V: SplDiscriminate + VariableLenPack>(
458 &mut self,
459 value: &V,
460 allow_repetition: bool,
461 ) -> Result<usize, ProgramError> {
462 let length = value.get_packed_len()?;
463 let (data, repetition_number) = self.alloc::<V>(length, allow_repetition)?;
464 value.pack_into_slice(data)?;
465 Ok(repetition_number)
466 }
467
468 pub fn realloc_with_repetition<V: SplDiscriminate>(
473 &mut self,
474 length: usize,
475 repetition_number: usize,
476 ) -> Result<&mut [u8], ProgramError> {
477 let TlvIndices {
478 type_start: _,
479 length_start,
480 value_start,
481 value_repetition_number: _,
482 } = get_indices(
483 self.data,
484 V::SPL_DISCRIMINATOR,
485 false,
486 Some(repetition_number),
487 )?;
488 let (_, end_index) = get_discriminators_and_end_index(self.data)?;
489 let data_len = self.data.len();
490
491 let length_ref = pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
492 let old_length = usize::try_from(*length_ref)?;
493
494 if old_length < length {
496 let new_end_index = end_index.saturating_add(length.saturating_sub(old_length));
497 if new_end_index > data_len {
498 return Err(ProgramError::InvalidAccountData);
499 }
500 }
501
502 *length_ref = Length::try_from(length)?;
505
506 let old_value_end = value_start.saturating_add(old_length);
507 let new_value_end = value_start.saturating_add(length);
508 self.data
509 .copy_within(old_value_end..end_index, new_value_end);
510 match old_length.cmp(&length) {
511 Ordering::Greater => {
512 let new_end_index = end_index.saturating_sub(old_length.saturating_sub(length));
514 self.data[new_end_index..end_index].fill(0);
515 }
516 Ordering::Less => {
517 self.data[old_value_end..new_value_end].fill(0);
519 }
520 Ordering::Equal => {} }
522
523 Ok(&mut self.data[value_start..new_value_end])
524 }
525
526 pub fn realloc_first<V: SplDiscriminate>(
529 &mut self,
530 length: usize,
531 ) -> Result<&mut [u8], ProgramError> {
532 self.realloc_with_repetition::<V>(length, 0)
533 }
534}
535
536impl TlvState for TlvStateMut<'_> {
537 fn get_data(&self) -> &[u8] {
538 self.data
539 }
540}
541
542pub fn realloc_and_pack_variable_len_with_repetition<V: SplDiscriminate + VariableLenPack>(
545 account_info: &AccountInfo,
546 value: &V,
547 repetition_number: usize,
548) -> Result<(), ProgramError> {
549 let previous_length = {
550 let data = account_info.try_borrow_data()?;
551 let TlvIndices {
552 type_start: _,
553 length_start,
554 value_start,
555 value_repetition_number: _,
556 } = get_indices(&data, V::SPL_DISCRIMINATOR, false, Some(repetition_number))?;
557 usize::try_from(*pod_from_bytes::<Length>(&data[length_start..value_start])?)?
558 };
559 let new_length = value.get_packed_len()?;
560 let previous_account_size = account_info.try_data_len()?;
561 if previous_length < new_length {
562 let additional_bytes = new_length
564 .checked_sub(previous_length)
565 .ok_or(ProgramError::AccountDataTooSmall)?;
566 account_info.realloc(previous_account_size.saturating_add(additional_bytes), true)?;
567 let mut buffer = account_info.try_borrow_mut_data()?;
568 let mut state = TlvStateMut::unpack(&mut buffer)?;
569 state.realloc_with_repetition::<V>(new_length, repetition_number)?;
570 state.pack_variable_len_value_with_repetition(value, repetition_number)?;
571 } else {
572 let mut buffer = account_info.try_borrow_mut_data()?;
574 let mut state = TlvStateMut::unpack(&mut buffer)?;
575 state.pack_variable_len_value_with_repetition(value, repetition_number)?;
576 let removed_bytes = previous_length
577 .checked_sub(new_length)
578 .ok_or(ProgramError::AccountDataTooSmall)?;
579 if removed_bytes > 0 {
580 state.realloc_with_repetition::<V>(new_length, repetition_number)?;
582 drop(buffer);
584 account_info.realloc(previous_account_size.saturating_sub(removed_bytes), false)?;
585 }
586 }
587 Ok(())
588}
589
590pub fn realloc_and_pack_first_variable_len<V: SplDiscriminate + VariableLenPack>(
593 account_info: &AccountInfo,
594 value: &V,
595) -> Result<(), ProgramError> {
596 realloc_and_pack_variable_len_with_repetition::<V>(account_info, value, 0)
597}
598
599const fn get_base_len() -> usize {
601 get_indices_unchecked(0, 0).value_start
602}
603
604fn check_data(tlv_data: &[u8]) -> Result<(), ProgramError> {
605 let _ = get_discriminators_and_end_index(tlv_data)?;
607 Ok(())
608}
609
610#[cfg(test)]
611mod test {
612 use {
613 super::*,
614 bytemuck::{Pod, Zeroable},
615 };
616
617 const TEST_BUFFER: &[u8] = &[
618 1, 1, 1, 1, 1, 1, 1, 1, 32, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
621 1, 1, 0, 0, ];
624
625 const TEST_BIG_BUFFER: &[u8] = &[
626 1, 1, 1, 1, 1, 1, 1, 1, 32, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
629 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
631 0, ];
633
634 #[repr(C)]
635 #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
636 struct TestValue {
637 data: [u8; 32],
638 }
639 impl SplDiscriminate for TestValue {
640 const SPL_DISCRIMINATOR: ArrayDiscriminator =
641 ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]);
642 }
643
644 #[repr(C)]
645 #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
646 struct TestSmallValue {
647 data: [u8; 3],
648 }
649 impl SplDiscriminate for TestSmallValue {
650 const SPL_DISCRIMINATOR: ArrayDiscriminator =
651 ArrayDiscriminator::new([2; ArrayDiscriminator::LENGTH]);
652 }
653
654 #[repr(transparent)]
655 #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
656 struct TestEmptyValue;
657 impl SplDiscriminate for TestEmptyValue {
658 const SPL_DISCRIMINATOR: ArrayDiscriminator =
659 ArrayDiscriminator::new([3; ArrayDiscriminator::LENGTH]);
660 }
661
662 #[repr(C)]
663 #[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
664 struct TestNonZeroDefault {
665 data: [u8; 5],
666 }
667 const TEST_NON_ZERO_DEFAULT_DATA: [u8; 5] = [4; 5];
668 impl SplDiscriminate for TestNonZeroDefault {
669 const SPL_DISCRIMINATOR: ArrayDiscriminator =
670 ArrayDiscriminator::new([4; ArrayDiscriminator::LENGTH]);
671 }
672 impl Default for TestNonZeroDefault {
673 fn default() -> Self {
674 Self {
675 data: TEST_NON_ZERO_DEFAULT_DATA,
676 }
677 }
678 }
679
680 #[test]
681 fn unpack_opaque_buffer() {
682 let state = TlvStateBorrowed::unpack(TEST_BUFFER).unwrap();
683 let value = state.get_first_value::<TestValue>().unwrap();
684 assert_eq!(value.data, [1; 32]);
685 assert_eq!(
686 state.get_first_value::<TestEmptyValue>(),
687 Err(ProgramError::InvalidAccountData)
688 );
689
690 let mut test_buffer = TEST_BUFFER.to_vec();
691 let state = TlvStateMut::unpack(&mut test_buffer).unwrap();
692 let value = state.get_first_value::<TestValue>().unwrap();
693 assert_eq!(value.data, [1; 32]);
694 let state = TlvStateOwned::unpack(test_buffer).unwrap();
695 let value = state.get_first_value::<TestValue>().unwrap();
696 assert_eq!(value.data, [1; 32]);
697 }
698
699 #[test]
700 fn fail_unpack_opaque_buffer() {
701 let mut buffer = vec![0, 3];
703 assert_eq!(
704 TlvStateBorrowed::unpack(&buffer),
705 Err(ProgramError::InvalidAccountData)
706 );
707 assert_eq!(
708 TlvStateMut::unpack(&mut buffer),
709 Err(ProgramError::InvalidAccountData)
710 );
711 assert_eq!(
712 TlvStateMut::unpack(&mut buffer),
713 Err(ProgramError::InvalidAccountData)
714 );
715
716 let mut buffer = TEST_BUFFER.to_vec();
718 buffer[0] += 1;
719 let state = TlvStateMut::unpack(&mut buffer).unwrap();
720 assert_eq!(
721 state.get_first_value::<TestValue>(),
722 Err(ProgramError::InvalidAccountData)
723 );
724
725 let mut buffer = TEST_BUFFER.to_vec();
727 buffer[ArrayDiscriminator::LENGTH] += 10;
728 assert_eq!(
729 TlvStateMut::unpack(&mut buffer),
730 Err(ProgramError::InvalidAccountData)
731 );
732
733 let mut buffer = TEST_BIG_BUFFER.to_vec();
735 buffer[ArrayDiscriminator::LENGTH] -= 1;
736 let state = TlvStateMut::unpack(&mut buffer).unwrap();
737 assert_eq!(
738 state.get_first_value::<TestValue>(),
739 Err(ProgramError::InvalidArgument)
740 );
741
742 let buffer = &TEST_BUFFER[..TEST_BUFFER.len() - 5];
744 assert_eq!(
745 TlvStateBorrowed::unpack(buffer),
746 Err(ProgramError::InvalidAccountData)
747 );
748 }
749
750 #[test]
751 fn get_discriminators_with_opaque_buffer() {
752 assert_eq!(
754 get_discriminators_and_end_index(&[1, 0, 1, 1]).unwrap_err(),
755 ProgramError::InvalidAccountData,
756 );
757 assert_eq!(
759 get_discriminators_and_end_index(&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
760 (vec![ArrayDiscriminator::from(1)], 12)
761 );
762 assert_eq!(
764 get_discriminators_and_end_index(&[0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
765 (vec![], 0)
766 );
767 }
768
769 #[test]
770 fn value_pack_unpack() {
771 let account_size =
772 get_base_len() + size_of::<TestValue>() + get_base_len() + size_of::<TestSmallValue>();
773 let mut buffer = vec![0; account_size];
774
775 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
776
777 let value = state.init_value::<TestValue>(false).unwrap().0;
779 let data = [100; 32];
780 value.data = data;
781 assert_eq!(
782 &state.get_discriminators().unwrap(),
783 &[TestValue::SPL_DISCRIMINATOR],
784 );
785 assert_eq!(&state.get_first_value::<TestValue>().unwrap().data, &data,);
786
787 assert_eq!(
789 state.init_value::<TestValue>(false).unwrap_err(),
790 TlvError::TypeAlreadyExists.into(),
791 );
792
793 let mut expect = vec![];
795 expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
796 expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
797 expect.extend_from_slice(&data);
798 expect.extend_from_slice(&[0; size_of::<ArrayDiscriminator>()]);
799 expect.extend_from_slice(&[0; size_of::<Length>()]);
800 expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
801 assert_eq!(expect, buffer);
802
803 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
805 let unpacked = state.get_first_value_mut::<TestValue>().unwrap();
806 assert_eq!(*unpacked, TestValue { data });
807
808 let new_data = [101; 32];
810 unpacked.data = new_data;
811
812 let state = TlvStateBorrowed::unpack(&buffer).unwrap();
814 let unpacked = state.get_first_value::<TestValue>().unwrap();
815 assert_eq!(*unpacked, TestValue { data: new_data });
816
817 let mut expect = vec![];
819 expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
820 expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
821 expect.extend_from_slice(&new_data);
822 expect.extend_from_slice(&[0; size_of::<ArrayDiscriminator>()]);
823 expect.extend_from_slice(&[0; size_of::<Length>()]);
824 expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
825 assert_eq!(expect, buffer);
826
827 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
828 let new_value = state.init_value::<TestSmallValue>(false).unwrap().0;
830 let small_data = [102; 3];
831 new_value.data = small_data;
832
833 assert_eq!(
834 &state.get_discriminators().unwrap(),
835 &[
836 TestValue::SPL_DISCRIMINATOR,
837 TestSmallValue::SPL_DISCRIMINATOR
838 ]
839 );
840
841 let mut expect = vec![];
843 expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
844 expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
845 expect.extend_from_slice(&new_data);
846 expect.extend_from_slice(TestSmallValue::SPL_DISCRIMINATOR.as_ref());
847 expect.extend_from_slice(
848 &u32::try_from(size_of::<TestSmallValue>())
849 .unwrap()
850 .to_le_bytes(),
851 );
852 expect.extend_from_slice(&small_data);
853 assert_eq!(expect, buffer);
854
855 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
857 assert_eq!(
858 state.init_value::<TestEmptyValue>(false),
859 Err(ProgramError::InvalidAccountData),
860 );
861 }
862
863 #[test]
864 fn value_any_order() {
865 let account_size =
866 get_base_len() + size_of::<TestValue>() + get_base_len() + size_of::<TestSmallValue>();
867 let mut buffer = vec![0; account_size];
868
869 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
870
871 let data = [99; 32];
872 let small_data = [98; 3];
873
874 let value = state.init_value::<TestValue>(false).unwrap().0;
876 value.data = data;
877 let value = state.init_value::<TestSmallValue>(false).unwrap().0;
878 value.data = small_data;
879
880 assert_eq!(
881 &state.get_discriminators().unwrap(),
882 &[
883 TestValue::SPL_DISCRIMINATOR,
884 TestSmallValue::SPL_DISCRIMINATOR,
885 ]
886 );
887
888 let mut other_buffer = vec![0; account_size];
890 let mut state = TlvStateMut::unpack(&mut other_buffer).unwrap();
891
892 let value = state.init_value::<TestSmallValue>(false).unwrap().0;
893 value.data = small_data;
894 let value = state.init_value::<TestValue>(false).unwrap().0;
895 value.data = data;
896
897 assert_eq!(
898 &state.get_discriminators().unwrap(),
899 &[
900 TestSmallValue::SPL_DISCRIMINATOR,
901 TestValue::SPL_DISCRIMINATOR,
902 ]
903 );
904
905 assert_ne!(buffer, other_buffer);
907 let state = TlvStateBorrowed::unpack(&buffer).unwrap();
908 let other_state = TlvStateBorrowed::unpack(&other_buffer).unwrap();
909
910 assert_eq!(
912 state.get_first_value::<TestValue>().unwrap(),
913 other_state.get_first_value::<TestValue>().unwrap()
914 );
915 assert_eq!(
916 state.get_first_value::<TestSmallValue>().unwrap(),
917 other_state.get_first_value::<TestSmallValue>().unwrap()
918 );
919 }
920
921 #[test]
922 fn init_nonzero_default() {
923 let account_size = get_base_len() + size_of::<TestNonZeroDefault>();
924 let mut buffer = vec![0; account_size];
925 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
926 let value = state.init_value::<TestNonZeroDefault>(false).unwrap().0;
927 assert_eq!(value.data, TEST_NON_ZERO_DEFAULT_DATA);
928 }
929
930 #[test]
931 fn init_buffer_too_small() {
932 let account_size = get_base_len() + size_of::<TestValue>();
933 let mut buffer = vec![0; account_size - 1];
934 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
935 let err = state.init_value::<TestValue>(false).unwrap_err();
936 assert_eq!(err, ProgramError::InvalidAccountData);
937
938 let discriminator_ref = &mut state.data[0..ArrayDiscriminator::LENGTH];
940 discriminator_ref.copy_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
941 state.data[ArrayDiscriminator::LENGTH] = 32;
942 let err = state.get_first_value::<TestValue>().unwrap_err();
943 assert_eq!(err, ProgramError::InvalidAccountData);
944 assert_eq!(
945 state.get_discriminators().unwrap_err(),
946 ProgramError::InvalidAccountData
947 );
948 }
949
950 #[test]
951 fn value_with_no_data() {
952 let account_size = get_base_len() + size_of::<TestEmptyValue>();
953 let mut buffer = vec![0; account_size];
954 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
955
956 assert_eq!(
957 state.get_first_value::<TestEmptyValue>().unwrap_err(),
958 TlvError::TypeNotFound.into(),
959 );
960
961 state.init_value::<TestEmptyValue>(false).unwrap();
962 state.get_first_value::<TestEmptyValue>().unwrap();
963
964 assert_eq!(
966 state.init_value::<TestEmptyValue>(false).unwrap_err(),
967 TlvError::TypeAlreadyExists.into(),
968 );
969 }
970
971 #[test]
972 fn alloc_first() {
973 let tlv_size = 1;
974 let account_size = get_base_len() + tlv_size;
975 let mut buffer = vec![0; account_size];
976 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
977
978 let data = state.alloc::<TestValue>(tlv_size, false).unwrap().0;
980 assert_eq!(
981 pod_from_bytes_mut::<TestValue>(data).unwrap_err(),
982 ProgramError::InvalidArgument,
983 );
984
985 assert_eq!(
987 state.alloc::<TestValue>(tlv_size, false).unwrap_err(),
988 TlvError::TypeAlreadyExists.into(),
989 );
990 }
991
992 #[test]
993 fn alloc_with_repetition() {
994 let tlv_size = 1;
995 let account_size = (get_base_len() + tlv_size) * 2;
996 let mut buffer = vec![0; account_size];
997 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
998
999 let (data, repetition_number) = state.alloc::<TestValue>(tlv_size, true).unwrap();
1000 assert_eq!(repetition_number, 0);
1001
1002 assert_eq!(
1004 pod_from_bytes_mut::<TestValue>(data).unwrap_err(),
1005 ProgramError::InvalidArgument,
1006 );
1007
1008 let (_data, repetition_number) = state.alloc::<TestValue>(tlv_size, true).unwrap();
1010 assert_eq!(repetition_number, 1);
1011 }
1012
1013 #[test]
1014 fn realloc_first() {
1015 const TLV_SIZE: usize = 10;
1016 const EXTRA_SPACE: usize = 5;
1017 const SMALL_SIZE: usize = 2;
1018 const ACCOUNT_SIZE: usize = get_base_len()
1019 + TLV_SIZE
1020 + EXTRA_SPACE
1021 + get_base_len()
1022 + size_of::<TestNonZeroDefault>();
1023 let mut buffer = vec![0; ACCOUNT_SIZE];
1024 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1025
1026 let _ = state.alloc::<TestValue>(TLV_SIZE, false).unwrap();
1028 let _ = state.init_value::<TestNonZeroDefault>(false).unwrap();
1029
1030 let data = state
1032 .realloc_first::<TestValue>(TLV_SIZE + EXTRA_SPACE)
1033 .unwrap();
1034 assert_eq!(data, [0; TLV_SIZE + EXTRA_SPACE]);
1035 let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
1036 assert_eq!(*value, TestNonZeroDefault::default());
1037
1038 let data = state.realloc_first::<TestValue>(SMALL_SIZE).unwrap();
1040 assert_eq!(data, [0; SMALL_SIZE]);
1041 let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
1042 assert_eq!(*value, TestNonZeroDefault::default());
1043 let (_, end_index) = get_discriminators_and_end_index(&buffer).unwrap();
1044 assert_eq!(
1045 &buffer[end_index..ACCOUNT_SIZE],
1046 [0; TLV_SIZE + EXTRA_SPACE - SMALL_SIZE]
1047 );
1048
1049 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1051 assert_eq!(
1053 state
1054 .realloc_first::<TestValue>(TLV_SIZE + EXTRA_SPACE + 1)
1055 .unwrap_err(),
1056 ProgramError::InvalidAccountData,
1057 );
1058 }
1059
1060 #[test]
1061 fn realloc_with_repeating_entries() {
1062 const TLV_SIZE: usize = 10;
1063 const EXTRA_SPACE: usize = 5;
1064 const SMALL_SIZE: usize = 2;
1065 const ACCOUNT_SIZE: usize = get_base_len()
1066 + TLV_SIZE
1067 + EXTRA_SPACE
1068 + get_base_len()
1069 + TLV_SIZE
1070 + get_base_len()
1071 + size_of::<TestNonZeroDefault>();
1072 let mut buffer = vec![0; ACCOUNT_SIZE];
1073 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1074
1075 let _ = state.alloc::<TestValue>(TLV_SIZE, true).unwrap();
1077 let _ = state.alloc::<TestValue>(TLV_SIZE, true).unwrap();
1078 let _ = state.init_value::<TestNonZeroDefault>(true).unwrap();
1079
1080 let data = state
1082 .realloc_with_repetition::<TestValue>(TLV_SIZE + EXTRA_SPACE, 0)
1083 .unwrap();
1084 assert_eq!(data, [0; TLV_SIZE + EXTRA_SPACE]);
1085 let value = state.get_bytes_with_repetition::<TestValue>(0).unwrap();
1086 assert_eq!(*value, [0; TLV_SIZE + EXTRA_SPACE]);
1087 let value = state.get_bytes_with_repetition::<TestValue>(1).unwrap();
1088 assert_eq!(*value, [0; TLV_SIZE]);
1089 let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
1090 assert_eq!(*value, TestNonZeroDefault::default());
1091
1092 let data = state
1094 .realloc_with_repetition::<TestValue>(SMALL_SIZE, 0)
1095 .unwrap();
1096 assert_eq!(data, [0; SMALL_SIZE]);
1097 let value = state.get_bytes_with_repetition::<TestValue>(0).unwrap();
1098 assert_eq!(*value, [0; SMALL_SIZE]);
1099 let value = state.get_bytes_with_repetition::<TestValue>(1).unwrap();
1100 assert_eq!(*value, [0; TLV_SIZE]);
1101 let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
1102 assert_eq!(*value, TestNonZeroDefault::default());
1103 let (_, end_index) = get_discriminators_and_end_index(&buffer).unwrap();
1104 assert_eq!(
1105 &buffer[end_index..ACCOUNT_SIZE],
1106 [0; TLV_SIZE + EXTRA_SPACE - SMALL_SIZE]
1107 );
1108
1109 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1111 assert_eq!(
1113 state
1114 .realloc_with_repetition::<TestValue>(TLV_SIZE + EXTRA_SPACE + 1, 0)
1115 .unwrap_err(),
1116 ProgramError::InvalidAccountData,
1117 );
1118 }
1119
1120 #[derive(Clone, Debug, PartialEq)]
1121 struct TestVariableLen {
1122 data: String, }
1124 impl SplDiscriminate for TestVariableLen {
1125 const SPL_DISCRIMINATOR: ArrayDiscriminator =
1126 ArrayDiscriminator::new([5; ArrayDiscriminator::LENGTH]);
1127 }
1128 impl VariableLenPack for TestVariableLen {
1129 fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1130 let bytes = self.data.as_bytes();
1131 let end = 8 + bytes.len();
1132 if dst.len() < end {
1133 Err(ProgramError::InvalidAccountData)
1134 } else {
1135 dst[..8].copy_from_slice(&self.data.len().to_le_bytes());
1136 dst[8..end].copy_from_slice(bytes);
1137 Ok(())
1138 }
1139 }
1140 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1141 let length = u64::from_le_bytes(src[..8].try_into().unwrap()) as usize;
1142 if src[8..8 + length].len() != length {
1143 return Err(ProgramError::InvalidAccountData);
1144 }
1145 let data = std::str::from_utf8(&src[8..8 + length])
1146 .unwrap()
1147 .to_string();
1148 Ok(Self { data })
1149 }
1150 fn get_packed_len(&self) -> Result<usize, ProgramError> {
1151 Ok(size_of::<u64>().saturating_add(self.data.len()))
1152 }
1153 }
1154
1155 #[test]
1156 fn first_variable_len_value() {
1157 let initial_data = "This is a pretty cool test!";
1158 let tlv_size = 8 + initial_data.len();
1160 let account_size = get_base_len() + tlv_size;
1161 let mut buffer = vec![0; account_size];
1162 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1163
1164 let _ = state.alloc::<TestVariableLen>(tlv_size, false).unwrap();
1166 let test_variable_len = TestVariableLen {
1167 data: initial_data.to_string(),
1168 };
1169 state
1170 .pack_first_variable_len_value(&test_variable_len)
1171 .unwrap();
1172 let deser = state
1173 .get_first_variable_len_value::<TestVariableLen>()
1174 .unwrap();
1175 assert_eq!(deser, test_variable_len);
1176
1177 let too_much_data = "This is a pretty cool test!?";
1179 assert_eq!(
1180 state
1181 .pack_first_variable_len_value(&TestVariableLen {
1182 data: too_much_data.to_string(),
1183 })
1184 .unwrap_err(),
1185 ProgramError::InvalidAccountData
1186 );
1187 }
1188
1189 #[test]
1190 fn variable_len_value_with_repetition() {
1191 let variable_len_1 = TestVariableLen {
1192 data: "Let's see if we can pack multiple variable length values".to_string(),
1193 };
1194 let tlv_size_1 = 8 + variable_len_1.data.len();
1195
1196 let variable_len_2 = TestVariableLen {
1197 data: "I think we can".to_string(),
1198 };
1199 let tlv_size_2 = 8 + variable_len_2.data.len();
1200
1201 let variable_len_3 = TestVariableLen {
1202 data: "In fact, I know we can!".to_string(),
1203 };
1204 let tlv_size_3 = 8 + variable_len_3.data.len();
1205
1206 let variable_len_4 = TestVariableLen {
1207 data: "How cool is this?".to_string(),
1208 };
1209 let tlv_size_4 = 8 + variable_len_4.data.len();
1210
1211 let account_size = get_base_len()
1212 + tlv_size_1
1213 + get_base_len()
1214 + tlv_size_2
1215 + get_base_len()
1216 + tlv_size_3
1217 + get_base_len()
1218 + tlv_size_4;
1219 let mut buffer = vec![0; account_size];
1220 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1221
1222 let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_1, true).unwrap();
1223 state
1224 .pack_variable_len_value_with_repetition(&variable_len_1, repetition_number)
1225 .unwrap();
1226 assert_eq!(repetition_number, 0);
1227 assert_eq!(
1228 state
1229 .get_first_variable_len_value::<TestVariableLen>()
1230 .unwrap(),
1231 variable_len_1,
1232 );
1233
1234 let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_2, true).unwrap();
1235 state
1236 .pack_variable_len_value_with_repetition(&variable_len_2, repetition_number)
1237 .unwrap();
1238 assert_eq!(repetition_number, 1);
1239 assert_eq!(
1240 state
1241 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1242 .unwrap(),
1243 variable_len_2,
1244 );
1245
1246 let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_3, true).unwrap();
1247 state
1248 .pack_variable_len_value_with_repetition(&variable_len_3, repetition_number)
1249 .unwrap();
1250 assert_eq!(repetition_number, 2);
1251 assert_eq!(
1252 state
1253 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1254 .unwrap(),
1255 variable_len_3,
1256 );
1257
1258 let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_4, true).unwrap();
1259 state
1260 .pack_variable_len_value_with_repetition(&variable_len_4, repetition_number)
1261 .unwrap();
1262 assert_eq!(repetition_number, 3);
1263 assert_eq!(
1264 state
1265 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1266 .unwrap(),
1267 variable_len_4,
1268 );
1269 }
1270
1271 #[test]
1272 fn add_entry_mix_and_match() {
1273 let mut buffer = vec![];
1274
1275 let fixed_data = TestValue { data: [1; 32] };
1277 let tlv_size = get_base_len() + size_of::<TestValue>();
1278 buffer.extend(vec![0; tlv_size]);
1279 {
1280 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1281 let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
1282 value.data = fixed_data.data;
1283 assert_eq!(repetition_number, 0);
1284 assert_eq!(*value, fixed_data);
1285 }
1286
1287 let variable_data = TestVariableLen {
1289 data: "This is my first variable length entry!".to_string(),
1290 };
1291 let tlv_size = get_base_len() + 8 + variable_data.data.len();
1292 buffer.extend(vec![0; tlv_size]);
1293 {
1294 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1295 let repetition_number = state
1296 .alloc_and_pack_variable_len_entry(&variable_data, true)
1297 .unwrap();
1298 let value = state
1299 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1300 .unwrap();
1301 assert_eq!(repetition_number, 0);
1302 assert_eq!(value, variable_data);
1303 }
1304
1305 let variable_data = TestVariableLen {
1307 data: "This is actually my second variable length entry!".to_string(),
1308 };
1309 let tlv_size = get_base_len() + 8 + variable_data.data.len();
1310 buffer.extend(vec![0; tlv_size]);
1311 {
1312 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1313 let repetition_number = state
1314 .alloc_and_pack_variable_len_entry(&variable_data, true)
1315 .unwrap();
1316 let value = state
1317 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1318 .unwrap();
1319 assert_eq!(repetition_number, 1);
1320 assert_eq!(value, variable_data);
1321 }
1322
1323 let fixed_data = TestValue { data: [2; 32] };
1325 let tlv_size = get_base_len() + size_of::<TestValue>();
1326 buffer.extend(vec![0; tlv_size]);
1327 {
1328 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1329 let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
1330 value.data = fixed_data.data;
1331 assert_eq!(repetition_number, 1);
1332 assert_eq!(*value, fixed_data);
1333 }
1334
1335 let fixed_data = TestValue { data: [3; 32] };
1337 let tlv_size = get_base_len() + size_of::<TestValue>();
1338 buffer.extend(vec![0; tlv_size]);
1339 {
1340 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1341 let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
1342 value.data = fixed_data.data;
1343 assert_eq!(repetition_number, 2);
1344 assert_eq!(*value, fixed_data);
1345 }
1346
1347 let variable_data = TestVariableLen {
1349 data: "Wow! My third variable length entry!".to_string(),
1350 };
1351 let tlv_size = get_base_len() + 8 + variable_data.data.len();
1352 buffer.extend(vec![0; tlv_size]);
1353 {
1354 let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
1355 let repetition_number = state
1356 .alloc_and_pack_variable_len_entry(&variable_data, true)
1357 .unwrap();
1358 let value = state
1359 .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
1360 .unwrap();
1361 assert_eq!(repetition_number, 2);
1362 assert_eq!(value, variable_data);
1363 }
1364 }
1365}