1use {
4 crate::{
5 error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32,
6 },
7 bytemuck::Pod,
8 solana_program_error::ProgramError,
9 std::ops::{Deref, DerefMut},
10};
11
12#[derive(Debug)]
13pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> {
14 pub(crate) length: &'data mut L,
15 pub(crate) data: &'data mut [T],
16 pub(crate) capacity: usize,
17}
18
19impl<T: Pod, L> ListViewMut<'_, T, L>
20where
21 L: PodLength,
22 PodSliceError: From<<L as TryFrom<usize>>::Error>,
23{
24 pub fn push(&mut self, item: T) -> Result<(), ProgramError> {
26 let length = (*self.length).into();
27 if length >= self.capacity {
28 Err(PodSliceError::BufferTooSmall.into())
29 } else {
30 self.data[length] = item;
31 *self.length = L::try_from(length.saturating_add(1)).map_err(PodSliceError::from)?;
32 Ok(())
33 }
34 }
35
36 pub fn remove(&mut self, index: usize) -> Result<T, ProgramError> {
39 let len = (*self.length).into();
40 if index >= len {
41 return Err(ProgramError::InvalidArgument);
42 }
43
44 let removed_item = self.data[index];
45
46 let tail_start = index
48 .checked_add(1)
49 .ok_or(ProgramError::ArithmeticOverflow)?;
50 self.data.copy_within(tail_start..len, index);
51
52 let new_len = len.checked_sub(1).unwrap();
54 *self.length = L::try_from(new_len).map_err(PodSliceError::from)?;
55
56 Ok(removed_item)
57 }
58}
59
60impl<T: Pod, L: PodLength> Deref for ListViewMut<'_, T, L> {
61 type Target = [T];
62
63 fn deref(&self) -> &Self::Target {
64 let len = (*self.length).into();
65 &self.data[..len]
66 }
67}
68
69impl<T: Pod, L: PodLength> DerefMut for ListViewMut<'_, T, L> {
70 fn deref_mut(&mut self) -> &mut Self::Target {
71 let len = (*self.length).into();
72 &mut self.data[..len]
73 }
74}
75
76impl<T: Pod, L: PodLength> List for ListViewMut<'_, T, L> {
77 type Item = T;
78 type Length = L;
79
80 fn capacity(&self) -> usize {
81 self.capacity
82 }
83}
84
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{
            list::{List, ListView},
            primitives::{PodU16, PodU32, PodU64},
        },
        bytemuck_derive::{Pod, Zeroable},
    };

    /// Test element with explicit trailing padding so it has no implicit padding
    /// bytes, as `Pod` requires.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)]
    struct TestStruct {
        a: u64,
        b: u32,
        _padding: [u8; 4],
    }

    impl TestStruct {
        fn new(a: u64, b: u32) -> Self {
            Self {
                a,
                b,
                _padding: [0; 4],
            }
        }
    }

    /// Resizes `buffer` to exactly fit `capacity` elements (plus the `L` length header)
    /// and initializes an empty mutable list view over it.
    fn init_view_mut<T: Pod, L: PodLength>(
        buffer: &mut Vec<u8>,
        capacity: usize,
    ) -> ListViewMut<T, L>
    where
        PodSliceError: From<<L as TryFrom<usize>>::Error>,
    {
        let size = ListView::<T, L>::size_of(capacity).unwrap();
        buffer.resize(size, 0);
        ListView::<T, L>::init(buffer).unwrap()
    }

    #[test]
    fn test_push() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(view.capacity(), 3);

        let item1 = TestStruct::new(1, 10);
        view.push(item1).unwrap();
        assert_eq!(view.len(), 1);
        assert!(!view.is_empty());
        assert_eq!(*view, [item1]);

        let item2 = TestStruct::new(2, 20);
        view.push(item2).unwrap();
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item2]);

        let item3 = TestStruct::new(3, 30);
        view.push(item3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);

        // The list is full: a further push fails and leaves the contents unchanged.
        let item4 = TestStruct::new(4, 40);
        let err = view.push(item4).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);
    }

    #[test]
    fn test_remove() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        let item4 = TestStruct::new(4, 40);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        assert_eq!(view.len(), 4);
        assert_eq!(*view, [item1, item2, item3, item4]);

        // Remove from the middle: later elements shift left, order is preserved.
        let removed = view.remove(1).unwrap();
        assert_eq!(removed, item2);
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item3, item4]);

        // Remove the last element.
        let removed = view.remove(2).unwrap();
        assert_eq!(removed, item4);
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item3]);

        // Remove the first element.
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item1);
        assert_eq!(view.len(), 1);
        assert_eq!(*view, [item3]);

        // Remove the only remaining element.
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item3);
        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(*view, []);
    }

    #[test]
    fn test_remove_out_of_bounds() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 2);

        view.push(TestStruct::new(1, 10)).unwrap();
        view.push(TestStruct::new(2, 20)).unwrap();

        // An index equal to the length is out of bounds; the list is left unchanged.
        let err = view.remove(2).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2);

        // An index far past the length is rejected the same way.
        let err = view.remove(100).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2);

        // Drain the list; removing from an empty list must also fail.
        view.remove(1).unwrap();
        view.remove(0).unwrap();
        assert!(view.is_empty());

        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_iter_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(view.capacity(), 4);

        // Mutate every element in place through the mutable iterator.
        for item in view.iter_mut() {
            item.a *= 10;
        }

        let expected_item1 = TestStruct::new(10, 10);
        let expected_item2 = TestStruct::new(20, 20);
        let expected_item3 = TestStruct::new(30, 30);

        assert_eq!(view.len(), 3);
        assert_eq!(*view, [expected_item1, expected_item2, expected_item3]);

        // The iterator covers only the occupied elements, not the spare capacity.
        assert_eq!(view.iter_mut().count(), 3);
    }

    #[test]
    fn test_iter_mut_empty() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU64>(&mut buffer, 5);

        let mut count = 0;
        for _ in view.iter_mut() {
            count += 1;
        }
        assert_eq!(count, 0);
        assert_eq!(view.iter_mut().next(), None);
    }

    #[test]
    fn test_zero_capacity() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 0);

        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), 0);
        assert!(view.is_empty());

        let err = view.push(TestStruct::new(1, 1)).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_default_length_type() {
        let capacity = 2;
        let mut buffer = vec![];
        // Size the buffer for the default length type (`PodU32`), which is what
        // `ListView::<TestStruct>` is expected to use.
        let size = ListView::<TestStruct, PodU32>::size_of(capacity).unwrap();
        buffer.resize(size, 0);

        let view = ListView::<TestStruct>::init(&mut buffer).unwrap();

        assert_eq!(view.capacity(), capacity);
        assert_eq!(view.len(), 0);

        // The default length header is a `PodU32`.
        assert_eq!(size_of_val(view.length), size_of::<PodU32>());
    }

    #[test]
    fn test_bytes_used_and_allocated_mut() {
        let mut buffer = vec![];
        // Expected sizes must be computed with the same length type as the view.
        let mut view = init_view_mut::<TestStruct, PodU16>(&mut buffer, 3);

        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(0).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );

        view.push(TestStruct::new(1, 2)).unwrap();
        view.push(TestStruct::new(3, 4)).unwrap();
        view.push(TestStruct::new(5, 6)).unwrap();
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(3).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );
    }

    #[test]
    fn test_get_and_get_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        assert_eq!(view.first(), Some(&item0));
        assert_eq!(view.get(1), Some(&item1));
        // Spare capacity and far-out indices are not reachable through `get`.
        assert_eq!(view.get(2), None);
        assert_eq!(view.get(100), None);

        let modified_item0 = TestStruct::new(111, 110);
        let item_ref = view.get_mut(0).unwrap();
        *item_ref = modified_item0;

        assert_eq!(view.first(), Some(&modified_item0));
        assert_eq!(*view, [modified_item0, item1]);

        assert_eq!(view.get_mut(2), None);
    }

    #[test]
    fn test_mutable_access_via_indexing() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        assert_eq!(view.len(), 2);

        view[0].a = 99;

        let expected_item0 = TestStruct::new(99, 10);
        assert_eq!(view.first(), Some(&expected_item0));
        assert_eq!(*view, [expected_item0, item1]);
    }

    #[test]
    fn test_sort_by() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 5);

        let item0 = TestStruct::new(5, 1);
        let item1 = TestStruct::new(2, 2);
        let item2 = TestStruct::new(5, 3);
        let item3 = TestStruct::new(1, 4);
        let item4 = TestStruct::new(2, 5);

        view.push(item0).unwrap();
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        // Sort by `b`, descending.
        view.sort_by(|a, b| b.b.cmp(&a.b));
        let expected_order_by_b_desc = [item4, item3, item2, item1, item0];
        assert_eq!(*view, expected_order_by_b_desc);

        // `sort_by` is stable: elements with equal `a` keep their b-descending order.
        view.sort_by(|x, y| x.a.cmp(&y.a));
        let expected_order_by_a_stable = [item3, item4, item1, item2, item0];
        assert_eq!(*view, expected_order_by_a_stable);
    }
}