spl_stake_pool/
big_vec.rs

1//! Big vector type, used with vectors that can't be serde'd
2#![allow(clippy::integer_arithmetic)] // checked math involves too many compute units
3
4use {
5    arrayref::array_ref,
6    borsh::{BorshDeserialize, BorshSerialize},
7    solana_program::{
8        program_error::ProgramError, program_memory::sol_memmove, program_pack::Pack,
9    },
10    std::marker::PhantomData,
11};
12
/// Utilities for working with a large vector of [`Pack`]-compatible elements
/// directly inside a borrowed byte buffer, so the whole collection never has
/// to be deserialized on-chain (which would blow through stack limits).
///
/// The buffer starts with a little-endian `u32` element count, followed by
/// the elements packed back to back.
pub struct BigVec<'data> {
    /// Backing byte buffer; elements are read and written in place
    pub data: &'data mut [u8],
}
19
/// Size in bytes of the `u32` length prefix stored at the start of the buffer
const VEC_SIZE_BYTES: usize = std::mem::size_of::<u32>();
21
22impl<'data> BigVec<'data> {
23    /// Get the length of the vector
24    pub fn len(&self) -> u32 {
25        let vec_len = array_ref![self.data, 0, VEC_SIZE_BYTES];
26        u32::from_le_bytes(*vec_len)
27    }
28
29    /// Find out if the vector has no contents (as demanded by clippy)
30    pub fn is_empty(&self) -> bool {
31        self.len() == 0
32    }
33
34    /// Retain all elements that match the provided function, discard all others
35    pub fn retain<T: Pack>(&mut self, predicate: fn(&[u8]) -> bool) -> Result<(), ProgramError> {
36        let mut vec_len = self.len();
37        let mut removals_found = 0;
38        let mut dst_start_index = 0;
39
40        let data_start_index = VEC_SIZE_BYTES;
41        let data_end_index =
42            data_start_index.saturating_add((vec_len as usize).saturating_mul(T::LEN));
43        for start_index in (data_start_index..data_end_index).step_by(T::LEN) {
44            let end_index = start_index + T::LEN;
45            let slice = &self.data[start_index..end_index];
46            if !predicate(slice) {
47                let gap = removals_found * T::LEN;
48                if removals_found > 0 {
49                    // In case the compute budget is ever bumped up, allowing us
50                    // to use this safe code instead:
51                    // self.data.copy_within(dst_start_index + gap..start_index, dst_start_index);
52                    unsafe {
53                        sol_memmove(
54                            self.data[dst_start_index..start_index - gap].as_mut_ptr(),
55                            self.data[dst_start_index + gap..start_index].as_mut_ptr(),
56                            start_index - gap - dst_start_index,
57                        );
58                    }
59                }
60                dst_start_index = start_index - gap;
61                removals_found += 1;
62                vec_len -= 1;
63            }
64        }
65
66        // final memmove
67        if removals_found > 0 {
68            let gap = removals_found * T::LEN;
69            // In case the compute budget is ever bumped up, allowing us
70            // to use this safe code instead:
71            //self.data.copy_within(dst_start_index + gap..data_end_index, dst_start_index);
72            unsafe {
73                sol_memmove(
74                    self.data[dst_start_index..data_end_index - gap].as_mut_ptr(),
75                    self.data[dst_start_index + gap..data_end_index].as_mut_ptr(),
76                    data_end_index - gap - dst_start_index,
77                );
78            }
79        }
80
81        let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
82        vec_len.serialize(&mut vec_len_ref)?;
83
84        Ok(())
85    }
86
87    /// Extracts a slice of the data types
88    pub fn deserialize_mut_slice<T: Pack>(
89        &mut self,
90        skip: usize,
91        len: usize,
92    ) -> Result<Vec<&'data mut T>, ProgramError> {
93        let vec_len = self.len();
94        let last_item_index = skip
95            .checked_add(len)
96            .ok_or(ProgramError::AccountDataTooSmall)?;
97        if last_item_index > vec_len as usize {
98            return Err(ProgramError::AccountDataTooSmall);
99        }
100
101        let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(T::LEN));
102        let end_index = start_index.saturating_add(len.saturating_mul(T::LEN));
103        let mut deserialized = vec![];
104        for slice in self.data[start_index..end_index].chunks_exact_mut(T::LEN) {
105            deserialized.push(unsafe { &mut *(slice.as_ptr() as *mut T) });
106        }
107        Ok(deserialized)
108    }
109
110    /// Add new element to the end
111    pub fn push<T: Pack>(&mut self, element: T) -> Result<(), ProgramError> {
112        let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
113        let mut vec_len = u32::try_from_slice(vec_len_ref)?;
114
115        let start_index = VEC_SIZE_BYTES + vec_len as usize * T::LEN;
116        let end_index = start_index + T::LEN;
117
118        vec_len += 1;
119        vec_len.serialize(&mut vec_len_ref)?;
120
121        if self.data.len() < end_index {
122            return Err(ProgramError::AccountDataTooSmall);
123        }
124        let element_ref = &mut self.data[start_index..start_index + T::LEN];
125        element.pack_into_slice(element_ref);
126        Ok(())
127    }
128
129    /// Get an iterator for the type provided
130    pub fn iter<'vec, T: Pack>(&'vec self) -> Iter<'data, 'vec, T> {
131        Iter {
132            len: self.len() as usize,
133            current: 0,
134            current_index: VEC_SIZE_BYTES,
135            inner: self,
136            phantom: PhantomData,
137        }
138    }
139
140    /// Get a mutable iterator for the type provided
141    pub fn iter_mut<'vec, T: Pack>(&'vec mut self) -> IterMut<'data, 'vec, T> {
142        IterMut {
143            len: self.len() as usize,
144            current: 0,
145            current_index: VEC_SIZE_BYTES,
146            inner: self,
147            phantom: PhantomData,
148        }
149    }
150
151    /// Find matching data in the array
152    pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
153        let len = self.len() as usize;
154        let mut current = 0;
155        let mut current_index = VEC_SIZE_BYTES;
156        while current != len {
157            let end_index = current_index + T::LEN;
158            let current_slice = &self.data[current_index..end_index];
159            if predicate(current_slice) {
160                return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
161            }
162            current_index = end_index;
163            current += 1;
164        }
165        None
166    }
167
168    /// Find matching data in the array
169    pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
170        let len = self.len() as usize;
171        let mut current = 0;
172        let mut current_index = VEC_SIZE_BYTES;
173        while current != len {
174            let end_index = current_index + T::LEN;
175            let current_slice = &self.data[current_index..end_index];
176            if predicate(current_slice) {
177                return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
178            }
179            current_index = end_index;
180            current += 1;
181        }
182        None
183    }
184}
185
186/// Iterator wrapper over a BigVec
187pub struct Iter<'data, 'vec, T> {
188    len: usize,
189    current: usize,
190    current_index: usize,
191    inner: &'vec BigVec<'data>,
192    phantom: PhantomData<T>,
193}
194
195impl<'data, 'vec, T: Pack + 'data> Iterator for Iter<'data, 'vec, T> {
196    type Item = &'data T;
197
198    fn next(&mut self) -> Option<Self::Item> {
199        if self.current == self.len {
200            None
201        } else {
202            let end_index = self.current_index + T::LEN;
203            let value = Some(unsafe {
204                &*(self.inner.data[self.current_index..end_index].as_ptr() as *const T)
205            });
206            self.current += 1;
207            self.current_index = end_index;
208            value
209        }
210    }
211}
212
213/// Iterator wrapper over a BigVec
214pub struct IterMut<'data, 'vec, T> {
215    len: usize,
216    current: usize,
217    current_index: usize,
218    inner: &'vec mut BigVec<'data>,
219    phantom: PhantomData<T>,
220}
221
222impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {
223    type Item = &'data mut T;
224
225    fn next(&mut self) -> Option<Self::Item> {
226        if self.current == self.len {
227            None
228        } else {
229            let end_index = self.current_index + T::LEN;
230            let value = Some(unsafe {
231                &mut *(self.inner.data[self.current_index..end_index].as_ptr() as *mut T)
232            });
233            self.current += 1;
234            self.current_index = end_index;
235            value
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use {super::*, solana_program::program_pack::Sealed};

    /// Eight-byte test element.
    ///
    /// The value is kept as little-endian bytes so the type has alignment 1:
    /// `BigVec` hands out references that point straight into a byte buffer
    /// at offset 4, which would be misaligned for a `u64` field.
    #[derive(Debug, PartialEq)]
    struct TestStruct {
        value: [u8; 8],
    }

    impl Sealed for TestStruct {}

    impl Pack for TestStruct {
        const LEN: usize = 8;
        fn pack_into_slice(&self, data: &mut [u8]) {
            data[..Self::LEN].copy_from_slice(&self.value);
        }
        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
            if src.len() != Self::LEN {
                return Err(ProgramError::InvalidAccountData);
            }
            let mut value = [0u8; 8];
            value.copy_from_slice(src);
            Ok(TestStruct { value })
        }
    }

    impl TestStruct {
        fn new(value: u64) -> Self {
            Self {
                value: value.to_le_bytes(),
            }
        }

        fn get(&self) -> u64 {
            u64::from_le_bytes(self.value)
        }

        fn set(&mut self, value: u64) {
            self.value = value.to_le_bytes();
        }
    }

    /// Builds a `BigVec` over `data` by pushing every value of `vec` in order
    fn from_slice<'data, 'other>(data: &'data mut [u8], vec: &'other [u64]) -> BigVec<'data> {
        let mut big_vec = BigVec { data };
        for element in vec {
            big_vec.push(TestStruct::new(*element)).unwrap();
        }
        big_vec
    }

    /// Asserts that `big_vec` holds exactly the values in `slice`, in order.
    ///
    /// The contents are collected first so that a length mismatch fails the
    /// assertion instead of being silently truncated by a `zip`.
    fn check_big_vec_eq(big_vec: &BigVec, slice: &[u64]) {
        let contents: Vec<u64> = big_vec
            .iter::<TestStruct>()
            .map(|element| element.get())
            .collect();
        assert_eq!(contents.as_slice(), slice);
    }

    #[test]
    fn push() {
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = BigVec { data: &mut data };
        v.push(TestStruct::new(1)).unwrap();
        check_big_vec_eq(&v, &[1]);
        v.push(TestStruct::new(2)).unwrap();
        check_big_vec_eq(&v, &[1, 2]);
        v.push(TestStruct::new(3)).unwrap();
        check_big_vec_eq(&v, &[1, 2, 3]);
        assert_eq!(
            v.push(TestStruct::new(4)).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn retain() {
        fn mod_2_predicate(data: &[u8]) -> bool {
            u64::try_from_slice(data).unwrap() % 2 == 0
        }

        // mixed: odd values are dropped, even values are compacted
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);

        // nothing to remove
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[2, 4, 6, 8]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4, 6, 8]);

        // everything removed
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 3, 5, 7]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[]);
    }

    fn find_predicate(a: &[u8], b: u64) -> bool {
        if a.len() != 8 {
            false
        } else {
            u64::try_from_slice(&a[0..8]).unwrap() == b
        }
    }

    #[test]
    fn find() {
        let mut data = [0u8; 4 + 8 * 4];
        let v = from_slice(&mut data, &[1, 2, 3, 4]);
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
            Some(&TestStruct::new(1))
        );
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
            Some(&TestStruct::new(4))
        );
        assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn find_mut() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let test_struct = v
            .find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
            .unwrap();
        test_struct.set(0);
        check_big_vec_eq(&v, &[0, 2, 3, 4]);
        assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn deserialize_mut_slice() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let mut slice = v.deserialize_mut_slice::<TestStruct>(1, 2).unwrap();
        slice[0].set(10);
        slice[1].set(11);
        check_big_vec_eq(&v, &[1, 10, 11, 4]);
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(1, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(4, 1).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }
}
364        );
365    }
366}