1#![allow(clippy::integer_arithmetic)] use {
5 arrayref::array_ref,
6 borsh::{BorshDeserialize, BorshSerialize},
7 solana_program::{
8 program_error::ProgramError, program_memory::sol_memmove, program_pack::Pack,
9 },
10 std::marker::PhantomData,
11};
12
/// A growable vector laid out in place inside a borrowed byte buffer.
///
/// The buffer starts with a `u32` element count (read little-endian by
/// `len`), followed by the elements packed back to back, each occupying
/// `T::LEN` bytes for the `Pack` type `T` chosen at each call site.
pub struct BigVec<'data> {
    /// Backing storage: length prefix followed by the packed elements.
    pub data: &'data mut [u8],
}
19
/// Number of bytes occupied by the `u32` element-count prefix at the start
/// of `BigVec::data`.
const VEC_SIZE_BYTES: usize = std::mem::size_of::<u32>();
21
22impl<'data> BigVec<'data> {
23 pub fn len(&self) -> u32 {
25 let vec_len = array_ref![self.data, 0, VEC_SIZE_BYTES];
26 u32::from_le_bytes(*vec_len)
27 }
28
29 pub fn is_empty(&self) -> bool {
31 self.len() == 0
32 }
33
34 pub fn retain<T: Pack>(&mut self, predicate: fn(&[u8]) -> bool) -> Result<(), ProgramError> {
36 let mut vec_len = self.len();
37 let mut removals_found = 0;
38 let mut dst_start_index = 0;
39
40 let data_start_index = VEC_SIZE_BYTES;
41 let data_end_index =
42 data_start_index.saturating_add((vec_len as usize).saturating_mul(T::LEN));
43 for start_index in (data_start_index..data_end_index).step_by(T::LEN) {
44 let end_index = start_index + T::LEN;
45 let slice = &self.data[start_index..end_index];
46 if !predicate(slice) {
47 let gap = removals_found * T::LEN;
48 if removals_found > 0 {
49 unsafe {
53 sol_memmove(
54 self.data[dst_start_index..start_index - gap].as_mut_ptr(),
55 self.data[dst_start_index + gap..start_index].as_mut_ptr(),
56 start_index - gap - dst_start_index,
57 );
58 }
59 }
60 dst_start_index = start_index - gap;
61 removals_found += 1;
62 vec_len -= 1;
63 }
64 }
65
66 if removals_found > 0 {
68 let gap = removals_found * T::LEN;
69 unsafe {
73 sol_memmove(
74 self.data[dst_start_index..data_end_index - gap].as_mut_ptr(),
75 self.data[dst_start_index + gap..data_end_index].as_mut_ptr(),
76 data_end_index - gap - dst_start_index,
77 );
78 }
79 }
80
81 let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
82 vec_len.serialize(&mut vec_len_ref)?;
83
84 Ok(())
85 }
86
87 pub fn deserialize_mut_slice<T: Pack>(
89 &mut self,
90 skip: usize,
91 len: usize,
92 ) -> Result<Vec<&'data mut T>, ProgramError> {
93 let vec_len = self.len();
94 let last_item_index = skip
95 .checked_add(len)
96 .ok_or(ProgramError::AccountDataTooSmall)?;
97 if last_item_index > vec_len as usize {
98 return Err(ProgramError::AccountDataTooSmall);
99 }
100
101 let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(T::LEN));
102 let end_index = start_index.saturating_add(len.saturating_mul(T::LEN));
103 let mut deserialized = vec![];
104 for slice in self.data[start_index..end_index].chunks_exact_mut(T::LEN) {
105 deserialized.push(unsafe { &mut *(slice.as_ptr() as *mut T) });
106 }
107 Ok(deserialized)
108 }
109
110 pub fn push<T: Pack>(&mut self, element: T) -> Result<(), ProgramError> {
112 let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
113 let mut vec_len = u32::try_from_slice(vec_len_ref)?;
114
115 let start_index = VEC_SIZE_BYTES + vec_len as usize * T::LEN;
116 let end_index = start_index + T::LEN;
117
118 vec_len += 1;
119 vec_len.serialize(&mut vec_len_ref)?;
120
121 if self.data.len() < end_index {
122 return Err(ProgramError::AccountDataTooSmall);
123 }
124 let element_ref = &mut self.data[start_index..start_index + T::LEN];
125 element.pack_into_slice(element_ref);
126 Ok(())
127 }
128
129 pub fn iter<'vec, T: Pack>(&'vec self) -> Iter<'data, 'vec, T> {
131 Iter {
132 len: self.len() as usize,
133 current: 0,
134 current_index: VEC_SIZE_BYTES,
135 inner: self,
136 phantom: PhantomData,
137 }
138 }
139
140 pub fn iter_mut<'vec, T: Pack>(&'vec mut self) -> IterMut<'data, 'vec, T> {
142 IterMut {
143 len: self.len() as usize,
144 current: 0,
145 current_index: VEC_SIZE_BYTES,
146 inner: self,
147 phantom: PhantomData,
148 }
149 }
150
151 pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
153 let len = self.len() as usize;
154 let mut current = 0;
155 let mut current_index = VEC_SIZE_BYTES;
156 while current != len {
157 let end_index = current_index + T::LEN;
158 let current_slice = &self.data[current_index..end_index];
159 if predicate(current_slice) {
160 return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
161 }
162 current_index = end_index;
163 current += 1;
164 }
165 None
166 }
167
168 pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
170 let len = self.len() as usize;
171 let mut current = 0;
172 let mut current_index = VEC_SIZE_BYTES;
173 while current != len {
174 let end_index = current_index + T::LEN;
175 let current_slice = &self.data[current_index..end_index];
176 if predicate(current_slice) {
177 return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
178 }
179 current_index = end_index;
180 current += 1;
181 }
182 None
183 }
184}
185
186pub struct Iter<'data, 'vec, T> {
188 len: usize,
189 current: usize,
190 current_index: usize,
191 inner: &'vec BigVec<'data>,
192 phantom: PhantomData<T>,
193}
194
195impl<'data, 'vec, T: Pack + 'data> Iterator for Iter<'data, 'vec, T> {
196 type Item = &'data T;
197
198 fn next(&mut self) -> Option<Self::Item> {
199 if self.current == self.len {
200 None
201 } else {
202 let end_index = self.current_index + T::LEN;
203 let value = Some(unsafe {
204 &*(self.inner.data[self.current_index..end_index].as_ptr() as *const T)
205 });
206 self.current += 1;
207 self.current_index = end_index;
208 value
209 }
210 }
211}
212
213pub struct IterMut<'data, 'vec, T> {
215 len: usize,
216 current: usize,
217 current_index: usize,
218 inner: &'vec mut BigVec<'data>,
219 phantom: PhantomData<T>,
220}
221
222impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {
223 type Item = &'data mut T;
224
225 fn next(&mut self) -> Option<Self::Item> {
226 if self.current == self.len {
227 None
228 } else {
229 let end_index = self.current_index + T::LEN;
230 let value = Some(unsafe {
231 &mut *(self.inner.data[self.current_index..end_index].as_ptr() as *mut T)
232 });
233 self.current += 1;
234 self.current_index = end_index;
235 value
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use {super::*, solana_program::program_pack::Sealed};

    // NOTE(review): elements start at byte offset 4 of a plain `u8` buffer,
    // so `TestStruct` (which holds a `u64`) is likely accessed through
    // pointers that are not aligned for `u64`. These tests rely on the host
    // tolerating that, and Miri would likely flag it — confirm whether the
    // test element type should be alignment-1.
    #[derive(Debug, PartialEq)]
    struct TestStruct {
        value: u64,
    }

    impl Sealed for TestStruct {}

    impl Pack for TestStruct {
        const LEN: usize = 8;
        fn pack_into_slice(&self, data: &mut [u8]) {
            let mut data = data;
            self.value.serialize(&mut data).unwrap();
        }
        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
            Ok(TestStruct {
                value: u64::try_from_slice(src).unwrap(),
            })
        }
    }

    impl TestStruct {
        fn new(value: u64) -> Self {
            Self { value }
        }
    }

    /// Builds a `BigVec` over `data` and pushes every value of `vec` as a
    /// `TestStruct`.
    fn from_slice<'data>(data: &'data mut [u8], vec: &[u64]) -> BigVec<'data> {
        let mut big_vec = BigVec { data };
        for element in vec {
            big_vec.push(TestStruct::new(*element)).unwrap();
        }
        big_vec
    }

    /// Asserts that `big_vec` holds exactly the values in `slice`, in order.
    fn check_big_vec_eq(big_vec: &BigVec, slice: &[u64]) {
        // `zip` alone would stop at the shorter side and hide a length
        // mismatch, so compare the length prefix first.
        assert_eq!(big_vec.len() as usize, slice.len());
        let values: Vec<u64> = big_vec.iter::<TestStruct>().map(|x| x.value).collect();
        assert_eq!(values, slice);
    }

    /// Keeps elements whose packed `u64` value is even.
    fn mod_2_predicate(data: &[u8]) -> bool {
        u64::try_from_slice(data).unwrap() % 2 == 0
    }

    #[test]
    fn push() {
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = BigVec { data: &mut data };
        v.push(TestStruct::new(1)).unwrap();
        check_big_vec_eq(&v, &[1]);
        v.push(TestStruct::new(2)).unwrap();
        check_big_vec_eq(&v, &[1, 2]);
        v.push(TestStruct::new(3)).unwrap();
        check_big_vec_eq(&v, &[1, 2, 3]);
        assert_eq!(
            v.push(TestStruct::new(4)).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn retain() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);
    }

    #[test]
    fn retain_edge_cases() {
        // Empty vector: nothing to remove.
        let mut empty = [0u8; 4];
        let mut v = BigVec { data: &mut empty };
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        assert!(v.is_empty());

        // Every element removed.
        let mut data = [0u8; 4 + 8 * 2];
        let mut v = from_slice(&mut data, &[1, 3]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[]);

        // Every element kept.
        let mut data = [0u8; 4 + 8 * 2];
        let mut v = from_slice(&mut data, &[2, 4]);
        v.retain::<TestStruct>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);
    }

    fn find_predicate(a: &[u8], b: u64) -> bool {
        if a.len() != 8 {
            false
        } else {
            u64::try_from_slice(&a[0..8]).unwrap() == b
        }
    }

    #[test]
    fn find() {
        let mut data = [0u8; 4 + 8 * 4];
        let v = from_slice(&mut data, &[1, 2, 3, 4]);
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
            Some(&TestStruct::new(1))
        );
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
            Some(&TestStruct::new(4))
        );
        assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn find_mut() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let test_struct = v
            .find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
            .unwrap();
        test_struct.value = 0;
        check_big_vec_eq(&v, &[0, 2, 3, 4]);
        assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn iter_mut() {
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = from_slice(&mut data, &[1, 2, 3]);
        for item in v.iter_mut::<TestStruct>() {
            item.value *= 10;
        }
        check_big_vec_eq(&v, &[10, 20, 30]);
    }

    #[test]
    fn deserialize_mut_slice() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let mut slice = v.deserialize_mut_slice::<TestStruct>(1, 2).unwrap();
        slice[0].value = 10;
        slice[1].value = 11;
        check_big_vec_eq(&v, &[1, 10, 11, 4]);
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(1, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(4, 1).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }
}