Skip to main content

spl_pod/list/
list_view_mut.rs

1//! `ListViewMut`, a mutable, compact, zero-copy array wrapper.
2
3use {
4    crate::{
5        error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32,
6    },
7    bytemuck::Pod,
8    solana_program_error::ProgramError,
9    std::ops::{Deref, DerefMut},
10};
11
12#[derive(Debug)]
13pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> {
14    pub(crate) length: &'data mut L,
15    pub(crate) data: &'data mut [T],
16    pub(crate) capacity: usize,
17}
18
19impl<T: Pod, L> ListViewMut<'_, T, L>
20where
21    L: PodLength,
22    PodSliceError: From<<L as TryFrom<usize>>::Error>,
23{
24    /// Add another item to the slice
25    pub fn push(&mut self, item: T) -> Result<(), ProgramError> {
26        let length = (*self.length).into();
27        if length >= self.capacity {
28            Err(PodSliceError::BufferTooSmall.into())
29        } else {
30            self.data[length] = item;
31            *self.length = L::try_from(length.saturating_add(1)).map_err(PodSliceError::from)?;
32            Ok(())
33        }
34    }
35
36    /// Remove and return the element at `index`, shifting all later
37    /// elements one position to the left.
38    pub fn remove(&mut self, index: usize) -> Result<T, ProgramError> {
39        let len = (*self.length).into();
40        if index >= len {
41            return Err(ProgramError::InvalidArgument);
42        }
43
44        let removed_item = self.data[index];
45
46        // Move the tail left by one
47        let tail_start = index
48            .checked_add(1)
49            .ok_or(ProgramError::ArithmeticOverflow)?;
50        self.data.copy_within(tail_start..len, index);
51
52        // Store the new length (len - 1)
53        let new_len = len.checked_sub(1).unwrap();
54        *self.length = L::try_from(new_len).map_err(PodSliceError::from)?;
55
56        Ok(removed_item)
57    }
58}
59
60impl<T: Pod, L: PodLength> Deref for ListViewMut<'_, T, L> {
61    type Target = [T];
62
63    fn deref(&self) -> &Self::Target {
64        let len = (*self.length).into();
65        &self.data[..len]
66    }
67}
68
69impl<T: Pod, L: PodLength> DerefMut for ListViewMut<'_, T, L> {
70    fn deref_mut(&mut self) -> &mut Self::Target {
71        let len = (*self.length).into();
72        &mut self.data[..len]
73    }
74}
75
76impl<T: Pod, L: PodLength> List for ListViewMut<'_, T, L> {
77    type Item = T;
78    type Length = L;
79
80    fn capacity(&self) -> usize {
81        self.capacity
82    }
83}
84
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{
            list::{List, ListView},
            primitives::{PodU16, PodU32, PodU64},
        },
        bytemuck_derive::{Pod, Zeroable},
    };

    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)]
    struct TestStruct {
        a: u64,
        b: u32,
        _padding: [u8; 4],
    }

    impl TestStruct {
        fn new(a: u64, b: u32) -> Self {
            Self {
                a,
                b,
                _padding: [0; 4],
            }
        }
    }

    /// Size `buffer` for exactly `capacity` items with length type `L` and
    /// initialize an empty mutable view over it.
    fn init_view_mut<T: Pod, L: PodLength>(
        buffer: &mut Vec<u8>,
        capacity: usize,
    ) -> ListViewMut<'_, T, L>
    where
        PodSliceError: From<<L as TryFrom<usize>>::Error>,
    {
        let size = ListView::<T, L>::size_of(capacity).unwrap();
        buffer.resize(size, 0);
        ListView::<T, L>::init(buffer).unwrap()
    }

    #[test]
    fn test_push() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(view.capacity(), 3);

        // Push first item
        let item1 = TestStruct::new(1, 10);
        view.push(item1).unwrap();
        assert_eq!(view.len(), 1);
        assert!(!view.is_empty());
        assert_eq!(*view, [item1]);

        // Push second item
        let item2 = TestStruct::new(2, 20);
        view.push(item2).unwrap();
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item2]);

        // Push third item to fill capacity
        let item3 = TestStruct::new(3, 30);
        view.push(item3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);

        // Try to push beyond capacity
        let item4 = TestStruct::new(4, 40);
        let err = view.push(item4).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        // Ensure state is unchanged
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);
    }

    #[test]
    fn test_remove() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        let item4 = TestStruct::new(4, 40);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        assert_eq!(view.len(), 4);
        assert_eq!(*view, [item1, item2, item3, item4]);

        // Remove from the middle
        let removed = view.remove(1).unwrap();
        assert_eq!(removed, item2);
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item3, item4]);

        // Remove from the end
        let removed = view.remove(2).unwrap();
        assert_eq!(removed, item4);
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item3]);

        // Remove from the start
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item1);
        assert_eq!(view.len(), 1);
        assert_eq!(*view, [item3]);

        // Remove the last element
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item3);
        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(*view, []);
    }

    #[test]
    fn test_remove_out_of_bounds() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 2);

        view.push(TestStruct::new(1, 10)).unwrap();
        view.push(TestStruct::new(2, 20)).unwrap();

        // Try to remove at index == len
        let err = view.remove(2).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Try to remove at index > len
        let err = view.remove(100).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Empty the view
        view.remove(1).unwrap();
        view.remove(0).unwrap();
        assert!(view.is_empty());

        // Try to remove from empty view
        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_iter_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(view.capacity(), 4);

        // Modify items using iter_mut
        for item in view.iter_mut() {
            item.a *= 10;
        }

        let expected_item1 = TestStruct::new(10, 10);
        let expected_item2 = TestStruct::new(20, 20);
        let expected_item3 = TestStruct::new(30, 30);

        // Check that the underlying data is modified
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [expected_item1, expected_item2, expected_item3]);

        // Check that iter_mut only iterates over `len` items, not `capacity`
        assert_eq!(view.iter_mut().count(), 3);
    }

    #[test]
    fn test_iter_mut_empty() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU64>(&mut buffer, 5);

        let mut count = 0;
        for _ in view.iter_mut() {
            count += 1;
        }
        assert_eq!(count, 0);
        assert_eq!(view.iter_mut().next(), None);
    }

    #[test]
    fn test_zero_capacity() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 0);

        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), 0);
        assert!(view.is_empty());

        let err = view.push(TestStruct::new(1, 1)).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_default_length_type() {
        let capacity = 2;
        let mut buffer = vec![];
        // Size the buffer for the default length type (`PodU32`), which is the
        // type `init` below will actually use.
        let size = ListView::<TestStruct, PodU32>::size_of(capacity).unwrap();
        buffer.resize(size, 0);

        // Initialize the view *without* specifying L. The compiler uses the default.
        let view = ListView::<TestStruct>::init(&mut buffer).unwrap();

        // Check that the capacity is correct for the default PodU32 length.
        assert_eq!(view.capacity(), capacity);
        assert_eq!(view.len(), 0);

        // Verify the size of the length field.
        assert_eq!(size_of_val(view.length), size_of::<PodU32>());
    }

    #[test]
    fn test_bytes_used_and_allocated_mut() {
        // capacity 3, start empty
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU16>(&mut buffer, 3);

        // Expected sizes are computed with the same length type as the view.
        // Empty view
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(0).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );

        // After pushing elements
        view.push(TestStruct::new(1, 2)).unwrap();
        view.push(TestStruct::new(3, 4)).unwrap();
        view.push(TestStruct::new(5, 6)).unwrap();
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(3).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );
    }

    #[test]
    fn test_get_and_get_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        // Test get()
        assert_eq!(view.first(), Some(&item0));
        assert_eq!(view.get(1), Some(&item1));
        assert_eq!(view.get(2), None); // out of bounds
        assert_eq!(view.get(100), None); // way out of bounds

        // Test get_mut() to modify an item
        let modified_item0 = TestStruct::new(111, 110);
        let item_ref = view.get_mut(0).unwrap();
        *item_ref = modified_item0;

        // Verify the modification
        assert_eq!(view.first(), Some(&modified_item0));
        assert_eq!(*view, [modified_item0, item1]);

        // Test get_mut() out of bounds
        assert_eq!(view.get_mut(2), None);
    }

    #[test]
    fn test_mutable_access_via_indexing() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        assert_eq!(view.len(), 2);

        // Modify via the mutable slice
        view[0].a = 99;

        let expected_item0 = TestStruct::new(99, 10);
        assert_eq!(view.first(), Some(&expected_item0));
        assert_eq!(*view, [expected_item0, item1]);
    }

    #[test]
    fn test_sort_by() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 5);

        let item0 = TestStruct::new(5, 1);
        let item1 = TestStruct::new(2, 2);
        let item2 = TestStruct::new(5, 3);
        let item3 = TestStruct::new(1, 4);
        let item4 = TestStruct::new(2, 5);

        view.push(item0).unwrap();
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        // Sort by `b` field in descending order.
        view.sort_by(|a, b| b.b.cmp(&a.b));
        let expected_order_by_b_desc = [
            item4, // b: 5
            item3, // b: 4
            item2, // b: 3
            item1, // b: 2
            item0, // b: 1
        ];
        assert_eq!(*view, expected_order_by_b_desc);

        // Now, sort by `a` in ascending order. A stable sort preserves the relative
        // order of equal elements from the previous state of the list.
        view.sort_by(|x, y| x.a.cmp(&y.a));

        let expected_order_by_a_stable = [
            item3, // a: 1
            item4, // a: 2 (was before item1 in the previous state)
            item1, // a: 2
            item2, // a: 5 (was before item0 in the previous state)
            item0, // a: 5
        ];
        assert_eq!(*view, expected_order_by_a_stable);
    }
}