spl_stake_pool/
big_vec.rs

1//! Big vector type, used with vectors that can't be deserialized on-chain
2#![allow(clippy::arithmetic_side_effects)] // checked math involves too many compute units
3
4use {
5    arrayref::array_ref,
6    borsh::BorshDeserialize,
7    bytemuck::Pod,
8    solana_program::{program_error::ProgramError, program_memory::sol_memmove},
9    std::mem,
10};
11
/// A length-prefixed vector stored directly in a byte buffer, so individual
/// elements can be read and modified in place instead of deserializing the whole
/// vector on-chain and blowing through stack limits.
///
/// Layout: a little-endian `u32` element count, followed by the elements packed
/// back to back. Element types must implement [`bytemuck::Pod`]; they are accessed
/// by reinterpreting their bytes, not by Borsh deserialization.
pub struct BigVec<'data> {
    /// Underlying buffer: the length prefix followed by the packed elements.
    pub data: &'data mut [u8],
}
18
/// Size in bytes of the little-endian `u32` element count that prefixes a
/// `BigVec` buffer.
const VEC_SIZE_BYTES: usize = mem::size_of::<u32>();
20
21impl BigVec<'_> {
22    /// Get the length of the vector
23    pub fn len(&self) -> u32 {
24        let vec_len = array_ref![self.data, 0, VEC_SIZE_BYTES];
25        u32::from_le_bytes(*vec_len)
26    }
27
28    /// Find out if the vector has no contents
29    pub fn is_empty(&self) -> bool {
30        self.len() == 0
31    }
32
33    /// Retain all elements that match the provided function, discard all others
34    pub fn retain<T: Pod, F: Fn(&[u8]) -> bool>(
35        &mut self,
36        predicate: F,
37    ) -> Result<(), ProgramError> {
38        let mut vec_len = self.len();
39        let mut removals_found = 0;
40        let mut dst_start_index = 0;
41
42        let data_start_index = VEC_SIZE_BYTES;
43        let data_end_index =
44            data_start_index.saturating_add((vec_len as usize).saturating_mul(mem::size_of::<T>()));
45        for start_index in (data_start_index..data_end_index).step_by(mem::size_of::<T>()) {
46            let end_index = start_index + mem::size_of::<T>();
47            let slice = &self.data[start_index..end_index];
48            if !predicate(slice) {
49                let gap = removals_found * mem::size_of::<T>();
50                if removals_found > 0 {
51                    // In case the compute budget is ever bumped up, allowing us
52                    // to use this safe code instead:
53                    // self.data.copy_within(dst_start_index + gap..start_index, dst_start_index);
54                    unsafe {
55                        sol_memmove(
56                            self.data[dst_start_index..start_index - gap].as_mut_ptr(),
57                            self.data[dst_start_index + gap..start_index].as_mut_ptr(),
58                            start_index - gap - dst_start_index,
59                        );
60                    }
61                }
62                dst_start_index = start_index - gap;
63                removals_found += 1;
64                vec_len -= 1;
65            }
66        }
67
68        // final memmove
69        if removals_found > 0 {
70            let gap = removals_found * mem::size_of::<T>();
71            // In case the compute budget is ever bumped up, allowing us
72            // to use this safe code instead:
73            //    self.data.copy_within(
74            //        dst_start_index + gap..data_end_index,
75            //        dst_start_index,
76            //    );
77            unsafe {
78                sol_memmove(
79                    self.data[dst_start_index..data_end_index - gap].as_mut_ptr(),
80                    self.data[dst_start_index + gap..data_end_index].as_mut_ptr(),
81                    data_end_index - gap - dst_start_index,
82                );
83            }
84        }
85
86        let vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
87        borsh::to_writer(vec_len_ref, &vec_len)?;
88
89        Ok(())
90    }
91
92    /// Extracts a slice of the data types
93    pub fn deserialize_mut_slice<T: Pod>(
94        &mut self,
95        skip: usize,
96        len: usize,
97    ) -> Result<&mut [T], ProgramError> {
98        let vec_len = self.len();
99        let last_item_index = skip
100            .checked_add(len)
101            .ok_or(ProgramError::AccountDataTooSmall)?;
102        if last_item_index > vec_len as usize {
103            return Err(ProgramError::AccountDataTooSmall);
104        }
105
106        let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
107        let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
108        bytemuck::try_cast_slice_mut(&mut self.data[start_index..end_index])
109            .map_err(|_| ProgramError::InvalidAccountData)
110    }
111
112    /// Extracts a slice of the data types
113    pub fn deserialize_slice<T: Pod>(&self, skip: usize, len: usize) -> Result<&[T], ProgramError> {
114        let vec_len = self.len();
115        let last_item_index = skip
116            .checked_add(len)
117            .ok_or(ProgramError::AccountDataTooSmall)?;
118        if last_item_index > vec_len as usize {
119            return Err(ProgramError::AccountDataTooSmall);
120        }
121
122        let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
123        let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
124        bytemuck::try_cast_slice(&self.data[start_index..end_index])
125            .map_err(|_| ProgramError::InvalidAccountData)
126    }
127
128    /// Add new element to the end
129    pub fn push<T: Pod>(&mut self, element: T) -> Result<(), ProgramError> {
130        let vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
131        let mut vec_len = u32::try_from_slice(vec_len_ref)?;
132
133        let start_index = VEC_SIZE_BYTES + vec_len as usize * mem::size_of::<T>();
134        let end_index = start_index + mem::size_of::<T>();
135
136        vec_len += 1;
137        borsh::to_writer(vec_len_ref, &vec_len)?;
138
139        if self.data.len() < end_index {
140            return Err(ProgramError::AccountDataTooSmall);
141        }
142        let element_ref = bytemuck::try_from_bytes_mut(
143            &mut self.data[start_index..start_index + mem::size_of::<T>()],
144        )
145        .map_err(|_| ProgramError::InvalidAccountData)?;
146        *element_ref = element;
147        Ok(())
148    }
149
150    /// Find matching data in the array
151    pub fn find<T: Pod, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
152        let len = self.len() as usize;
153        let mut current = 0;
154        let mut current_index = VEC_SIZE_BYTES;
155        while current != len {
156            let end_index = current_index + mem::size_of::<T>();
157            let current_slice = &self.data[current_index..end_index];
158            if predicate(current_slice) {
159                return Some(bytemuck::from_bytes(current_slice));
160            }
161            current_index = end_index;
162            current += 1;
163        }
164        None
165    }
166
167    /// Find matching data in the array
168    pub fn find_mut<T: Pod, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
169        let len = self.len() as usize;
170        let mut current = 0;
171        let mut current_index = VEC_SIZE_BYTES;
172        while current != len {
173            let end_index = current_index + mem::size_of::<T>();
174            let current_slice = &self.data[current_index..end_index];
175            if predicate(current_slice) {
176                return Some(bytemuck::from_bytes_mut(
177                    &mut self.data[current_index..end_index],
178                ));
179            }
180            current_index = end_index;
181            current += 1;
182        }
183        None
184    }
185}
186
#[cfg(test)]
mod tests {
    use {super::*, bytemuck::Zeroable};

    #[repr(C)]
    #[derive(Debug, Copy, Clone, PartialEq, Pod, Zeroable)]
    struct TestStruct {
        value: [u8; 8],
    }

    impl TestStruct {
        fn new(value: u8) -> Self {
            let value = [value, 0, 0, 0, 0, 0, 0, 0];
            Self { value }
        }
    }

    /// Builds a `BigVec` over `data` by pushing one `TestStruct` per entry of `vec`.
    fn from_slice<'data>(data: &'data mut [u8], vec: &[u8]) -> BigVec<'data> {
        let mut big_vec = BigVec { data };
        for element in vec {
            big_vec.push(TestStruct::new(*element)).unwrap();
        }
        big_vec
    }

    /// Asserts that `big_vec` holds exactly one element per entry of `slice`, and
    /// that each element's first byte equals the corresponding entry.
    fn check_big_vec_eq(big_vec: &BigVec, slice: &[u8]) {
        let first_bytes: Vec<u8> = big_vec
            .deserialize_slice::<TestStruct>(0, big_vec.len() as usize)
            .unwrap()
            .iter()
            .map(|x| x.value[0])
            .collect();
        // Compare whole sequences: zipping would silently ignore a length mismatch.
        assert_eq!(first_bytes, slice);
    }

    /// Keeps elements whose little-endian `u64` value is even.
    fn mod_2_predicate(data: &[u8]) -> bool {
        u64::try_from_slice(data).unwrap() % 2 == 0
    }

    #[test]
    fn push() {
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = BigVec { data: &mut data };
        v.push(TestStruct::new(1)).unwrap();
        check_big_vec_eq(&v, &[1]);
        v.push(TestStruct::new(2)).unwrap();
        check_big_vec_eq(&v, &[1, 2]);
        v.push(TestStruct::new(3)).unwrap();
        check_big_vec_eq(&v, &[1, 2, 3]);
        assert_eq!(
            v.push(TestStruct::new(4)).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn retain() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);
    }

    #[test]
    fn retain_edge_cases() {
        // Runs of consecutive removals at the front, in the middle and at the back.
        let mut data = [0u8; 4 + 8 * 7];
        let mut v = from_slice(&mut data, &[1, 3, 2, 5, 7, 4, 9]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);

        // Nothing removed.
        let mut data = [0u8; 4 + 8 * 2];
        let mut v = from_slice(&mut data, &[2, 4]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);

        // Everything removed.
        let mut data = [0u8; 4 + 8 * 2];
        let mut v = from_slice(&mut data, &[1, 3]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        assert!(v.is_empty());
        check_big_vec_eq(&v, &[]);
    }

    /// Matches an 8-byte element whose first byte is `b`.
    fn find_predicate(a: &[u8], b: u8) -> bool {
        if a.len() != 8 {
            false
        } else {
            a[0] == b
        }
    }

    #[test]
    fn find() {
        let mut data = [0u8; 4 + 8 * 4];
        let v = from_slice(&mut data, &[1, 2, 3, 4]);
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
            Some(&TestStruct::new(1))
        );
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
            Some(&TestStruct::new(4))
        );
        assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn find_mut() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let test_struct = v
            .find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
            .unwrap();
        test_struct.value = [0; 8];
        check_big_vec_eq(&v, &[0, 2, 3, 4]);
        assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn deserialize_mut_slice() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let slice = v.deserialize_mut_slice::<TestStruct>(1, 2).unwrap();
        slice[0].value[0] = 10;
        slice[1].value[0] = 11;
        check_big_vec_eq(&v, &[1, 10, 11, 4]);
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(1, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(4, 1).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }
}