1#![allow(clippy::arithmetic_side_effects)] use {
5 arrayref::array_ref,
6 borsh::BorshDeserialize,
7 bytemuck::Pod,
8 solana_program::{program_error::ProgramError, program_memory::sol_memmove},
9 std::mem,
10};
11
/// Growable vector of `Pod` elements stored inside a borrowed byte buffer.
///
/// Layout of `data`: a little-endian `u32` element count in the first four
/// bytes, followed by the elements packed back to back, each occupying
/// `size_of::<T>()` bytes.
#[derive(Debug)]
pub struct BigVec<'data> {
    /// Backing storage: the length prefix followed by the packed elements.
    pub data: &'data mut [u8],
}
18
/// Size in bytes of the `u32` element-count prefix at the start of the buffer.
const VEC_SIZE_BYTES: usize = mem::size_of::<u32>();
20
21impl BigVec<'_> {
22 pub fn len(&self) -> u32 {
24 let vec_len = array_ref![self.data, 0, VEC_SIZE_BYTES];
25 u32::from_le_bytes(*vec_len)
26 }
27
28 pub fn is_empty(&self) -> bool {
30 self.len() == 0
31 }
32
33 pub fn retain<T: Pod, F: Fn(&[u8]) -> bool>(
35 &mut self,
36 predicate: F,
37 ) -> Result<(), ProgramError> {
38 let mut vec_len = self.len();
39 let mut removals_found = 0;
40 let mut dst_start_index = 0;
41
42 let data_start_index = VEC_SIZE_BYTES;
43 let data_end_index =
44 data_start_index.saturating_add((vec_len as usize).saturating_mul(mem::size_of::<T>()));
45 for start_index in (data_start_index..data_end_index).step_by(mem::size_of::<T>()) {
46 let end_index = start_index + mem::size_of::<T>();
47 let slice = &self.data[start_index..end_index];
48 if !predicate(slice) {
49 let gap = removals_found * mem::size_of::<T>();
50 if removals_found > 0 {
51 unsafe {
55 sol_memmove(
56 self.data[dst_start_index..start_index - gap].as_mut_ptr(),
57 self.data[dst_start_index + gap..start_index].as_mut_ptr(),
58 start_index - gap - dst_start_index,
59 );
60 }
61 }
62 dst_start_index = start_index - gap;
63 removals_found += 1;
64 vec_len -= 1;
65 }
66 }
67
68 if removals_found > 0 {
70 let gap = removals_found * mem::size_of::<T>();
71 unsafe {
78 sol_memmove(
79 self.data[dst_start_index..data_end_index - gap].as_mut_ptr(),
80 self.data[dst_start_index + gap..data_end_index].as_mut_ptr(),
81 data_end_index - gap - dst_start_index,
82 );
83 }
84 }
85
86 let vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
87 borsh::to_writer(vec_len_ref, &vec_len)?;
88
89 Ok(())
90 }
91
92 pub fn deserialize_mut_slice<T: Pod>(
94 &mut self,
95 skip: usize,
96 len: usize,
97 ) -> Result<&mut [T], ProgramError> {
98 let vec_len = self.len();
99 let last_item_index = skip
100 .checked_add(len)
101 .ok_or(ProgramError::AccountDataTooSmall)?;
102 if last_item_index > vec_len as usize {
103 return Err(ProgramError::AccountDataTooSmall);
104 }
105
106 let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
107 let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
108 bytemuck::try_cast_slice_mut(&mut self.data[start_index..end_index])
109 .map_err(|_| ProgramError::InvalidAccountData)
110 }
111
112 pub fn deserialize_slice<T: Pod>(&self, skip: usize, len: usize) -> Result<&[T], ProgramError> {
114 let vec_len = self.len();
115 let last_item_index = skip
116 .checked_add(len)
117 .ok_or(ProgramError::AccountDataTooSmall)?;
118 if last_item_index > vec_len as usize {
119 return Err(ProgramError::AccountDataTooSmall);
120 }
121
122 let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
123 let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
124 bytemuck::try_cast_slice(&self.data[start_index..end_index])
125 .map_err(|_| ProgramError::InvalidAccountData)
126 }
127
128 pub fn push<T: Pod>(&mut self, element: T) -> Result<(), ProgramError> {
130 let vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
131 let mut vec_len = u32::try_from_slice(vec_len_ref)?;
132
133 let start_index = VEC_SIZE_BYTES + vec_len as usize * mem::size_of::<T>();
134 let end_index = start_index + mem::size_of::<T>();
135
136 vec_len += 1;
137 borsh::to_writer(vec_len_ref, &vec_len)?;
138
139 if self.data.len() < end_index {
140 return Err(ProgramError::AccountDataTooSmall);
141 }
142 let element_ref = bytemuck::try_from_bytes_mut(
143 &mut self.data[start_index..start_index + mem::size_of::<T>()],
144 )
145 .map_err(|_| ProgramError::InvalidAccountData)?;
146 *element_ref = element;
147 Ok(())
148 }
149
150 pub fn find<T: Pod, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
152 let len = self.len() as usize;
153 let mut current = 0;
154 let mut current_index = VEC_SIZE_BYTES;
155 while current != len {
156 let end_index = current_index + mem::size_of::<T>();
157 let current_slice = &self.data[current_index..end_index];
158 if predicate(current_slice) {
159 return Some(bytemuck::from_bytes(current_slice));
160 }
161 current_index = end_index;
162 current += 1;
163 }
164 None
165 }
166
167 pub fn find_mut<T: Pod, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
169 let len = self.len() as usize;
170 let mut current = 0;
171 let mut current_index = VEC_SIZE_BYTES;
172 while current != len {
173 let end_index = current_index + mem::size_of::<T>();
174 let current_slice = &self.data[current_index..end_index];
175 if predicate(current_slice) {
176 return Some(bytemuck::from_bytes_mut(
177 &mut self.data[current_index..end_index],
178 ));
179 }
180 current_index = end_index;
181 current += 1;
182 }
183 None
184 }
185}
186
#[cfg(test)]
mod tests {
    use {super::*, bytemuck::Zeroable};

    #[repr(C)]
    #[derive(Debug, Copy, Clone, PartialEq, Pod, Zeroable)]
    struct TestStruct {
        value: [u8; 8],
    }

    impl TestStruct {
        /// Builds an element whose first byte is `value` and the rest zero.
        fn new(value: u8) -> Self {
            let value = [value, 0, 0, 0, 0, 0, 0, 0];
            Self { value }
        }
    }

    /// Wraps `data` in a `BigVec` and pushes one `TestStruct` per byte of `vec`.
    fn from_slice<'data>(data: &'data mut [u8], vec: &[u8]) -> BigVec<'data> {
        let mut big_vec = BigVec { data };
        for element in vec {
            big_vec.push(TestStruct::new(*element)).unwrap();
        }
        big_vec
    }

    /// Asserts that `big_vec` holds exactly the elements tagged by `slice`
    /// (compared by first byte).
    ///
    /// Whole sequences are compared so a length mismatch fails; a
    /// `zip`-based comparison would silently truncate to the shorter side.
    fn check_big_vec_eq(big_vec: &BigVec, slice: &[u8]) {
        let actual: Vec<u8> = big_vec
            .deserialize_slice::<TestStruct>(0, big_vec.len() as usize)
            .unwrap()
            .iter()
            .map(|x| x.value[0])
            .collect();
        assert_eq!(actual, slice);
    }

    #[test]
    fn push() {
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = BigVec { data: &mut data };
        v.push(TestStruct::new(1)).unwrap();
        check_big_vec_eq(&v, &[1]);
        v.push(TestStruct::new(2)).unwrap();
        check_big_vec_eq(&v, &[1, 2]);
        v.push(TestStruct::new(3)).unwrap();
        check_big_vec_eq(&v, &[1, 2, 3]);
        assert_eq!(
            v.push(TestStruct::new(4)).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn retain() {
        fn mod_2_predicate(data: &[u8]) -> bool {
            u64::try_from_slice(data).unwrap() % 2 == 0
        }

        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);

        // Nothing removed.
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = from_slice(&mut data, &[2, 4, 6]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4, 6]);

        // Everything removed.
        let mut data = [0u8; 4 + 8 * 3];
        let mut v = from_slice(&mut data, &[1, 3, 5]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[]);
        assert!(v.is_empty());

        // Consecutive removals at the front and in the middle.
        let mut data = [0u8; 4 + 8 * 6];
        let mut v = from_slice(&mut data, &[1, 3, 2, 5, 7, 4]);
        v.retain::<TestStruct, _>(mod_2_predicate).unwrap();
        check_big_vec_eq(&v, &[2, 4]);
    }

    fn find_predicate(a: &[u8], b: u8) -> bool {
        if a.len() != 8 {
            false
        } else {
            a[0] == b
        }
    }

    #[test]
    fn find() {
        let mut data = [0u8; 4 + 8 * 4];
        let v = from_slice(&mut data, &[1, 2, 3, 4]);
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
            Some(&TestStruct::new(1))
        );
        assert_eq!(
            v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
            Some(&TestStruct::new(4))
        );
        assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);

        // An empty vector never matches.
        let mut empty_data = [0u8; 4];
        let empty = BigVec {
            data: &mut empty_data,
        };
        assert_eq!(empty.find::<TestStruct, _>(|x| find_predicate(x, 0)), None);
    }

    #[test]
    fn find_mut() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let test_struct = v
            .find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
            .unwrap();
        test_struct.value = [0; 8];
        check_big_vec_eq(&v, &[0, 2, 3, 4]);
        assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
    }

    #[test]
    fn deserialize_slice() {
        let mut data = [0u8; 4 + 8 * 4];
        let v = from_slice(&mut data, &[1, 2, 3, 4]);
        let slice = v.deserialize_slice::<TestStruct>(1, 2).unwrap();
        assert_eq!(slice, &[TestStruct::new(2), TestStruct::new(3)]);
        assert!(v.deserialize_slice::<TestStruct>(4, 0).unwrap().is_empty());
        assert_eq!(
            v.deserialize_slice::<TestStruct>(1, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            v.deserialize_slice::<TestStruct>(4, 1).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn deserialize_mut_slice() {
        let mut data = [0u8; 4 + 8 * 4];
        let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
        let slice = v.deserialize_mut_slice::<TestStruct>(1, 2).unwrap();
        slice[0].value[0] = 10;
        slice[1].value[0] = 11;
        check_big_vec_eq(&v, &[1, 10, 11, 4]);
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(1, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            v.deserialize_mut_slice::<TestStruct>(4, 1).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }
}