Skip to main content

steel/account/
deserialize.rs

1use bytemuck::Pod;
2use solana_program::program_error::ProgramError;
3
/// Associates a type with a one-byte tag.
///
/// The blanket deserialization impls in this file compare this tag against
/// the first byte of the raw account data before interpreting the rest.
pub trait Discriminator {
    /// Returns the tag byte for this type.
    fn discriminator() -> u8;
}
7
8pub trait AccountDeserialize {
9    fn try_from_bytes(data: &[u8]) -> Result<&Self, ProgramError>;
10    fn try_from_bytes_mut(data: &mut [u8]) -> Result<&mut Self, ProgramError>;
11}
12
13impl<T> AccountDeserialize for T
14where
15    T: Discriminator + Pod,
16{
17    fn try_from_bytes(data: &[u8]) -> Result<&Self, ProgramError> {
18        if Self::discriminator().ne(&data[0]) {
19            return Err(solana_program::program_error::ProgramError::InvalidAccountData);
20        }
21        bytemuck::try_from_bytes::<Self>(&data[8..]).or(Err(
22            solana_program::program_error::ProgramError::InvalidAccountData,
23        ))
24    }
25
26    fn try_from_bytes_mut(data: &mut [u8]) -> Result<&mut Self, ProgramError> {
27        if Self::discriminator().ne(&data[0]) {
28            return Err(solana_program::program_error::ProgramError::InvalidAccountData);
29        }
30        bytemuck::try_from_bytes_mut::<Self>(&mut data[8..]).or(Err(
31            solana_program::program_error::ProgramError::InvalidAccountData,
32        ))
33    }
34}
35
36/// Account data is sometimes stored via a header and body type,
37/// where the former resolves the type of the latter (e.g. merkle trees with a generic size const).
38/// This trait parses a header type from the first N bytes of some data, and returns the remaining
39/// bytes, which are then available for further processing.
40///
41/// See module-level tests for example usage.
42pub trait AccountHeaderDeserialize {
43    fn try_header_from_bytes(data: &[u8]) -> Result<(&Self, &[u8]), ProgramError>;
44    fn try_header_from_bytes_mut(data: &mut [u8]) -> Result<(&mut Self, &mut [u8]), ProgramError>;
45}
46
47impl<T> AccountHeaderDeserialize for T
48where
49    T: Discriminator + Pod,
50{
51    fn try_header_from_bytes(data: &[u8]) -> Result<(&Self, &[u8]), ProgramError> {
52        if Self::discriminator().ne(&data[0]) {
53            return Err(solana_program::program_error::ProgramError::InvalidAccountData);
54        }
55        let (prefix, remainder) = data[8..].split_at(std::mem::size_of::<T>());
56        Ok((
57            bytemuck::try_from_bytes::<Self>(prefix).or(Err(
58                solana_program::program_error::ProgramError::InvalidAccountData,
59            ))?,
60            remainder,
61        ))
62    }
63
64    fn try_header_from_bytes_mut(data: &mut [u8]) -> Result<(&mut Self, &mut [u8]), ProgramError> {
65        let (prefix, remainder) = data[8..].split_at_mut(std::mem::size_of::<T>());
66        Ok((
67            bytemuck::try_from_bytes_mut::<Self>(prefix).or(Err(
68                solana_program::program_error::ProgramError::InvalidAccountData,
69            ))?,
70            remainder,
71        ))
72    }
73}
74
#[cfg(test)]
mod tests {
    use crate::AccountDeserialize;

    use super::*;
    use bytemuck::{Pod, Zeroable};

    /// Body type whose size is chosen by a const generic; the matching header
    /// below records which `N` to use.
    #[repr(C)]
    #[derive(Copy, Clone)]
    struct GenericallySizedType<const N: usize> {
        field: [u32; N],
    }

    // Implemented by hand rather than derived; the struct is `repr(C)` and
    // consists solely of a `[u32; N]` array.
    unsafe impl<const N: usize> Zeroable for GenericallySizedType<N> {}
    unsafe impl<const N: usize> Pod for GenericallySizedType<N> {}

    /// Header carrying the element count of the body that follows it.
    #[repr(C)]
    #[derive(Copy, Clone, Zeroable, Pod)]
    struct GenericallySizedTypeHeader {
        field_len: u64,
    }

    impl Discriminator for GenericallySizedTypeHeader {
        fn discriminator() -> u8 {
            0
        }
    }

    /// Parses a header, uses it to pick the body type, then parses the body
    /// from the bytes the header parser handed back.
    #[test]
    fn account_headers() {
        // Byte 0 stays 0 (the header's discriminator), bytes 8..16 hold
        // `field_len`, and bytes 16..32 hold the four-element body.
        let mut bytes = [0u8; 32];
        bytes[8] = 4;
        bytes[16] = 5;

        let (header, remainder) =
            GenericallySizedTypeHeader::try_header_from_bytes(&bytes).unwrap();
        let body = match header.field_len {
            4 => bytemuck::try_from_bytes::<GenericallySizedType<4>>(remainder).unwrap(),
            x => panic!("{}", format!("unknown field len, {x}")),
        };

        assert_eq!(5, body.field[0]);
    }

    /// Plain fixed-size account type used to exercise `AccountDeserialize`.
    #[repr(C)]
    #[derive(Copy, Clone, Zeroable, Pod)]
    struct TestType {
        field0: u64,
        field1: u64,
    }

    impl Discriminator for TestType {
        fn discriminator() -> u8 {
            7
        }
    }

    /// Round-trips a hand-built buffer through `try_from_bytes`.
    #[test]
    fn account_deserialize() {
        // Byte 0 is the discriminator; the two `u64` fields start at byte 8.
        let mut bytes = [0u8; 24];
        bytes[0] = 7;
        bytes[8] = 42;
        bytes[16] = 43;

        let parsed = TestType::try_from_bytes(&bytes).unwrap();

        assert_eq!(42, parsed.field0);
        assert_eq!(43, parsed.field1);
    }
}
143}