squads_multisig_program/utils/
small_vec.rs

1use std::io::{Read, Write};
2use std::marker::PhantomData;
3
4use anchor_lang::prelude::*;
5
/// A `Vec<T>` with a compact borsh encoding.
///
/// The element count is stored as a value of type `L` instead of borsh's
/// default `u32`, followed by the elements themselves. `L` is typically a small
/// unsigned integer such as `u8` or `u16`.
///
/// Deserialization accepts any `L` that implements `AnchorDeserialize` and
/// converts into `u32`. `L` only appears as a type marker and is never stored.
#[derive(Debug, Clone, Default)]
pub struct SmallVec<L, T>(Vec<T>, PhantomData<L>);
11
12impl<L, T> SmallVec<L, T> {
13    pub fn len(&self) -> usize {
14        self.0.len()
15    }
16
17    pub fn is_empty(&self) -> bool {
18        self.0.is_empty()
19    }
20}
21
22impl<L, T> From<SmallVec<L, T>> for Vec<T> {
23    fn from(val: SmallVec<L, T>) -> Self {
24        val.0
25    }
26}
27
28impl<L, T> From<Vec<T>> for SmallVec<L, T> {
29    fn from(val: Vec<T>) -> Self {
30        Self(val, PhantomData)
31    }
32}
33
34impl<T: AnchorSerialize> AnchorSerialize for SmallVec<u8, T> {
35    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
36        let len = u8::try_from(self.len()).map_err(|_| std::io::ErrorKind::InvalidInput)?;
37        // Write the length of the vector as u8.
38        writer.write_all(&len.to_le_bytes())?;
39
40        // Write the vector elements.
41        serialize_slice(&self.0, writer)
42    }
43}
44
45impl<T: AnchorSerialize> AnchorSerialize for SmallVec<u16, T> {
46    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
47        let len = u16::try_from(self.len()).map_err(|_| std::io::ErrorKind::InvalidInput)?;
48        // Write the length of the vector as u16.
49        writer.write_all(&len.to_le_bytes())?;
50
51        // Write the vector elements.
52        serialize_slice(&self.0, writer)
53    }
54}
55
56impl<L, T> AnchorDeserialize for SmallVec<L, T>
57where
58    L: AnchorDeserialize + Into<u32>,
59    T: AnchorDeserialize,
60{
61    /// This implementation almost exactly matches standard implementation of
62    /// `Vec<T>::deserialize` except that it uses `L` instead of `u32` for the length,
63    /// and doesn't include `unsafe` code.
64    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
65        let len: u32 = L::deserialize_reader(reader)?.into();
66
67        let vec = if len == 0 {
68            Vec::new()
69        } else if let Some(vec_bytes) = T::vec_from_reader(len, reader)? {
70            vec_bytes
71        } else {
72            let mut result = Vec::with_capacity(hint::cautious::<T>(len));
73            for _ in 0..len {
74                result.push(T::deserialize_reader(reader)?);
75            }
76            result
77        };
78
79        Ok(SmallVec(vec, PhantomData))
80    }
81}
82
// Adapted from `borsh::de::hint`. Borsh only calls its version after handling
// zero-sized element types separately; this module must cope with them itself.
mod hint {
    /// Returns a conservative initial capacity for a vector that is about to
    /// receive `hint` elements of type `T` read from input.
    ///
    /// The result is `min(hint, 4096 / size_of::<T>())`, floored at 1:
    /// - The pre-allocation stays within 4096 bytes.
    /// - If `T` is larger than 4096 bytes, the capacity is a single element.
    /// - Zero-sized types never allocate, so for them the hint is returned
    ///   uncapped (still at least 1). The upstream formula would divide by zero
    ///   and panic here.
    #[inline]
    pub fn cautious<T>(hint: u32) -> usize {
        // Saturate instead of truncating if `T` is absurdly large; that still
        // yields a cap of 0 elements, which the floor below raises to 1.
        let el_size = u32::try_from(core::mem::size_of::<T>()).unwrap_or(u32::MAX);
        // `checked_div` is `None` only for zero-sized `T`, which needs no cap.
        let max_elements = 4096u32.checked_div(el_size).unwrap_or(u32::MAX);
        core::cmp::max(core::cmp::min(hint, max_elements), 1) as usize
    }
}
91
92/// Helper method that is used to serialize a slice of data (without the length marker).
93/// Copied from borsh::ser::serialize_slice.
94#[inline]
95fn serialize_slice<T: AnchorSerialize, W: Write>(
96    data: &[T],
97    writer: &mut W,
98) -> std::io::Result<()> {
99    if let Some(u8_slice) = T::u8_slice(data) {
100        writer.write_all(u8_slice)?;
101    } else {
102        for item in data {
103            item.serialize(writer)?;
104        }
105    }
106    Ok(())
107}
108
#[cfg(test)]
mod test {
    use super::*;

    mod deserialize {
        use super::*;

        #[test]
        fn test_length_u8_type_u8() {
            let mut input = &[
                0x02, // len (2)
                0x05, // vec[0]
                0x09, // vec[1]
            ][..];

            let small_vec: SmallVec<u8, u8> = SmallVec::deserialize(&mut input).unwrap();

            assert_eq!(small_vec.0, vec![5, 9]);
        }

        #[test]
        fn test_length_u8_type_u32() {
            let mut input = &[
                0x02, // len (2)
                0x05, 0x00, 0x00, 0x00, // vec[0]
                0x09, 0x00, 0x00, 0x00, // vec[1]
            ][..];

            let small_vec: SmallVec<u8, u32> = SmallVec::deserialize(&mut input).unwrap();

            assert_eq!(small_vec.0, vec![5, 9]);
        }

        #[test]
        fn test_length_u8_type_pubkey() {
            let pubkey1 = Pubkey::new_unique();
            let pubkey2 = Pubkey::new_unique();
            let mut input = &[
                &[0x02], // len (2)
                &pubkey1.try_to_vec().unwrap()[..],
                &pubkey2.try_to_vec().unwrap()[..],
            ]
            .concat()[..];

            let small_vec: SmallVec<u8, Pubkey> = SmallVec::deserialize(&mut input).unwrap();

            assert_eq!(small_vec.0, vec![pubkey1, pubkey2]);
        }

        #[test]
        fn test_length_u16_type_u8() {
            let mut input = &[
                0x02, 0x00, // len (2)
                0x05, // vec[0]
                0x09, // vec[1]
            ][..];

            let small_vec: SmallVec<u16, u8> = SmallVec::deserialize(&mut input).unwrap();

            assert_eq!(small_vec.0, vec![5, 9]);
        }

        #[test]
        fn test_length_u16_type_pubkey() {
            let pubkey1 = Pubkey::new_unique();
            let pubkey2 = Pubkey::new_unique();
            let mut input = &[
                &[0x02, 0x00], // len (2)
                &pubkey1.try_to_vec().unwrap()[..],
                &pubkey2.try_to_vec().unwrap()[..],
            ]
            .concat()[..];

            let small_vec: SmallVec<u16, Pubkey> = SmallVec::deserialize(&mut input).unwrap();

            assert_eq!(small_vec.0, vec![pubkey1, pubkey2]);
        }

        #[test]
        fn test_length_u8_empty() {
            let mut input = &[
                0x00, // len (0)
            ][..];

            let small_vec: SmallVec<u8, u32> = SmallVec::deserialize(&mut input).unwrap();

            assert!(small_vec.is_empty());
            // Only the length prefix should have been consumed.
            assert!(input.is_empty());
        }

        #[test]
        fn test_truncated_input_is_error() {
            // Element-by-element path: declares two u32 elements, provides one.
            let mut input = &[
                0x02, // len (2)
                0x05, 0x00, 0x00, 0x00, // vec[0]
            ][..];
            assert!(SmallVec::<u8, u32>::deserialize(&mut input).is_err());

            // Byte fast path: declares two u8 elements, provides one.
            let mut input = &[
                0x02, 0x00, // len (2)
                0x05, // vec[0]
            ][..];
            assert!(SmallVec::<u16, u8>::deserialize(&mut input).is_err());
        }
    }

    mod serialize {
        use super::*;

        #[test]
        fn test_length_u8_type_u8() {
            let small_vec = SmallVec::<u8, u8>::from(vec![3, 5]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(
                output,
                vec![
                    0x02, // len (2)
                    0x03, // vec[0]
                    0x05, // vec[1]
                ]
            );
        }

        #[test]
        fn test_length_u8_type_u32() {
            let small_vec = SmallVec::<u8, u32>::from(vec![3, 5]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(
                output,
                vec![
                    0x02, // len (2)
                    0x03, 0x00, 0x00, 0x00, // vec[0]
                    0x05, 0x00, 0x00, 0x00, // vec[1]
                ]
            );
        }

        #[test]
        fn test_length_u8_type_pubkey() {
            let pubkey1 = Pubkey::new_unique();
            let pubkey2 = Pubkey::new_unique();
            let small_vec = SmallVec::<u8, Pubkey>::from(vec![pubkey1, pubkey2]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(
                output,
                [
                    &[0x02], // len (2)
                    &pubkey1.to_bytes()[..],
                    &pubkey2.to_bytes()[..],
                ]
                .concat()[..]
            );
        }

        #[test]
        fn test_length_u16_type_u8() {
            let small_vec = SmallVec::<u16, u8>::from(vec![3, 5]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(
                output,
                vec![
                    0x02, 0x00, // len (2)
                    0x03, // vec[0]
                    0x05, // vec[1]
                ]
            );
        }

        #[test]
        fn test_length_u16_type_pubkey() {
            let pubkey1 = Pubkey::new_unique();
            let pubkey2 = Pubkey::new_unique();
            let small_vec = SmallVec::<u16, Pubkey>::from(vec![pubkey1, pubkey2]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(
                output,
                [
                    &[0x02, 0x00], // len (2)
                    &pubkey1.to_bytes()[..],
                    &pubkey2.to_bytes()[..],
                ]
                .concat()[..]
            );
        }

        #[test]
        fn test_length_u8_empty() {
            let small_vec = SmallVec::<u8, u32>::from(vec![]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(output, vec![0x00]); // len (0), no elements
        }

        #[test]
        fn test_length_u8_max_fits() {
            let small_vec = SmallVec::<u8, u8>::from(vec![0; 255]);

            let mut output = vec![];
            small_vec.serialize(&mut output).unwrap();

            assert_eq!(output.len(), 1 + 255);
            assert_eq!(output[0], 0xff);
        }

        #[test]
        fn test_length_u8_overflow() {
            let small_vec = SmallVec::<u8, u8>::from(vec![0; 256]);

            let mut output = vec![];
            let err = small_vec.serialize(&mut output).unwrap_err();

            assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
            // The length is validated before anything is written.
            assert!(output.is_empty());
        }

        #[test]
        fn test_length_u16_overflow() {
            let small_vec = SmallVec::<u16, u8>::from(vec![0; usize::from(u16::MAX) + 1]);

            let mut output = vec![];
            let err = small_vec.serialize(&mut output).unwrap_err();

            assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
            assert!(output.is_empty());
        }
    }

    mod roundtrip {
        use super::*;

        #[test]
        fn test_length_u16_type_pubkey() {
            let pubkeys = vec![
                Pubkey::new_unique(),
                Pubkey::new_unique(),
                Pubkey::new_unique(),
            ];

            let bytes = SmallVec::<u16, Pubkey>::from(pubkeys.clone())
                .try_to_vec()
                .unwrap();
            assert_eq!(bytes.len(), 2 + 3 * 32);

            let decoded = SmallVec::<u16, Pubkey>::try_from_slice(&bytes).unwrap();
            assert_eq!(Vec::from(decoded), pubkeys);
        }

        #[test]
        fn test_length_u8_type_u32() {
            let values = vec![1u32, u32::MAX, 0];

            let bytes = SmallVec::<u8, u32>::from(values.clone())
                .try_to_vec()
                .unwrap();
            let decoded = SmallVec::<u8, u32>::try_from_slice(&bytes).unwrap();

            assert_eq!(decoded.len(), 3);
            assert_eq!(Vec::from(decoded), values);
        }
    }
}