Skip to main content

spl_list_view/
list_view_mut.rs

1//! `ListViewMut`, a mutable, compact, zero-copy array wrapper.
2
3use {
4    crate::list_trait::List,
5    bytemuck::Pod,
6    core::ops::{Deref, DerefMut},
7    solana_program_error::ProgramError,
8    spl_pod::{error::PodSliceError, pod_length::PodLength, primitives::PodU32},
9};
10
/// Mutable view over a length-prefixed list of `T` values.
///
/// The view borrows the stored element count and the element storage, and
/// remembers the maximum number of elements it may hold. Dereferencing the
/// view yields only the first `length` elements of `data`.
#[derive(Debug)]
pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> {
    /// Number of elements currently in use, stored as `L` (`PodU32` unless
    /// another length type is chosen).
    pub(crate) length: &'data mut L,
    /// Element storage; slots at index `length` and beyond are not exposed
    /// through `Deref`/`DerefMut`.
    pub(crate) data: &'data mut [T],
    /// Maximum number of elements `push` will accept.
    // NOTE(review): presumably equal to `data.len()`; the constructor is not
    // in this file, so confirm there.
    pub(crate) capacity: usize,
}
17
18impl<T: Pod, L: PodLength> ListViewMut<'_, T, L> {
19    /// Add another item to the slice
20    pub fn push(&mut self, item: T) -> Result<(), ProgramError> {
21        let length = (*self.length).into();
22        if length >= self.capacity {
23            Err(PodSliceError::BufferTooSmall.into())
24        } else {
25            self.data[length] = item;
26            *self.length = L::try_from(length.saturating_add(1))?;
27            Ok(())
28        }
29    }
30
31    /// Remove and return the element at `index`, shifting all later
32    /// elements one position to the left.
33    pub fn remove(&mut self, index: usize) -> Result<T, ProgramError> {
34        let len = (*self.length).into();
35        if index >= len {
36            return Err(ProgramError::InvalidArgument);
37        }
38
39        let removed_item = self.data[index];
40
41        // Move the tail left by one
42        let tail_start = index
43            .checked_add(1)
44            .ok_or(ProgramError::ArithmeticOverflow)?;
45        self.data.copy_within(tail_start..len, index);
46
47        // Store the new length (len - 1)
48        let new_len = len.checked_sub(1).unwrap();
49        *self.length = L::try_from(new_len)?;
50
51        Ok(removed_item)
52    }
53}
54
55impl<T: Pod, L: PodLength> Deref for ListViewMut<'_, T, L> {
56    type Target = [T];
57
58    fn deref(&self) -> &Self::Target {
59        let len = (*self.length).into();
60        &self.data[..len]
61    }
62}
63
64impl<T: Pod, L: PodLength> DerefMut for ListViewMut<'_, T, L> {
65    fn deref_mut(&mut self) -> &mut Self::Target {
66        let len = (*self.length).into();
67        &mut self.data[..len]
68    }
69}
70
// Implements the crate's `List` trait for the mutable view.
// NOTE(review): the trait is defined in `crate::list_trait`, outside this
// file; any provided methods it offers are presumably built on `capacity()`
// and the `Deref` impl — confirm against the trait definition.
impl<T: Pod, L: PodLength> List for ListViewMut<'_, T, L> {
    /// Element type stored in the list.
    type Item = T;
    /// Type used to store the list's length.
    type Length = L;

    /// Returns the maximum number of elements this view may hold, as
    /// recorded in the `capacity` field.
    fn capacity(&self) -> usize {
        self.capacity
    }
}
79
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{List, ListView},
        bytemuck_derive::{Pod, Zeroable},
        spl_pod::primitives::{PodU16, PodU32, PodU64},
    };

    /// 16-byte element type used by every test; the explicit padding keeps
    /// the layout free of implicit padding bytes, as the `Pod` derive requires.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)]
    struct TestStruct {
        a: u64,
        b: u32,
        _padding: [u8; 4],
    }

    impl TestStruct {
        /// Builds a `TestStruct` with zeroed padding.
        fn new(a: u64, b: u32) -> Self {
            Self {
                a,
                b,
                _padding: [0; 4],
            }
        }
    }

    /// Resizes `buffer` to exactly fit `capacity` elements of `T` behind a
    /// length of type `L`, zero-fills it and initializes an empty view on it.
    fn init_view_mut<T: Pod, L: PodLength>(
        buffer: &mut Vec<u8>,
        capacity: usize,
    ) -> ListViewMut<'_, T, L> {
        let size = ListView::<T, L>::size_of(capacity).unwrap();
        buffer.resize(size, 0);
        ListView::<T, L>::init(buffer).unwrap()
    }

    #[test]
    fn test_push() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(view.capacity(), 3);

        // Push first item
        let item1 = TestStruct::new(1, 10);
        view.push(item1).unwrap();
        assert_eq!(view.len(), 1);
        assert!(!view.is_empty());
        assert_eq!(*view, [item1]);

        // Push second item
        let item2 = TestStruct::new(2, 20);
        view.push(item2).unwrap();
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item2]);

        // Push third item to fill capacity
        let item3 = TestStruct::new(3, 30);
        view.push(item3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);

        // Try to push beyond capacity
        let item4 = TestStruct::new(4, 40);
        let err = view.push(item4).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        // Ensure state is unchanged
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);
    }

    #[test]
    fn test_remove() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        let item4 = TestStruct::new(4, 40);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        assert_eq!(view.len(), 4);
        assert_eq!(*view, [item1, item2, item3, item4]);

        // Remove from the middle
        let removed = view.remove(1).unwrap();
        assert_eq!(removed, item2);
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item3, item4]);

        // Remove from the end
        let removed = view.remove(2).unwrap();
        assert_eq!(removed, item4);
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item3]);

        // Remove from the start
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item1);
        assert_eq!(view.len(), 1);
        assert_eq!(*view, [item3]);

        // Remove the last element
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item3);
        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(*view, []);
    }

    #[test]
    fn test_remove_out_of_bounds() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 2);

        view.push(TestStruct::new(1, 10)).unwrap();
        view.push(TestStruct::new(2, 20)).unwrap();

        // Try to remove at index == len
        let err = view.remove(2).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Try to remove at index > len
        let err = view.remove(100).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Empty the view
        view.remove(1).unwrap();
        view.remove(0).unwrap();
        assert!(view.is_empty());

        // Try to remove from empty view
        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_iter_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(view.capacity(), 4);

        // Modify items using iter_mut
        for item in view.iter_mut() {
            item.a *= 10;
        }

        let expected_item1 = TestStruct::new(10, 10);
        let expected_item2 = TestStruct::new(20, 20);
        let expected_item3 = TestStruct::new(30, 30);

        // Check that the underlying data is modified
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [expected_item1, expected_item2, expected_item3]);

        // Check that iter_mut only iterates over `len` items, not `capacity`
        assert_eq!(view.iter_mut().count(), 3);
    }

    #[test]
    fn test_iter_mut_empty() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU64>(&mut buffer, 5);

        let mut count = 0;
        for _ in view.iter_mut() {
            count += 1;
        }
        assert_eq!(count, 0);
        assert_eq!(view.iter_mut().next(), None);
    }

    #[test]
    fn test_zero_capacity() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 0);

        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), 0);
        assert!(view.is_empty());

        let err = view.push(TestStruct::new(1, 1)).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_default_length_type() {
        let capacity = 2;
        let mut buffer = vec![];
        // Size the buffer with the same defaulted length type that `init`
        // uses below, so both sides of the test agree on the layout.
        let size = ListView::<TestStruct>::size_of(capacity).unwrap();
        buffer.resize(size, 0);

        // Initialize the view *without* specifying L. The compiler uses the default.
        let view = ListView::<TestStruct>::init(&mut buffer).unwrap();

        // Check that the capacity is correct for the default (PodU32) length.
        assert_eq!(view.capacity(), capacity);
        assert_eq!(view.len(), 0);

        // Verify the size of the length field.
        assert_eq!(size_of_val(view.length), size_of::<PodU32>());
    }

    #[test]
    fn test_bytes_used_and_allocated_mut() {
        // capacity 3, start empty
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU16>(&mut buffer, 3);

        // Empty view. Expected sizes use the view's own length type (PodU16).
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(0).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );

        // After pushing elements
        view.push(TestStruct::new(1, 2)).unwrap();
        view.push(TestStruct::new(3, 4)).unwrap();
        view.push(TestStruct::new(5, 6)).unwrap();
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(3).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );
    }

    #[test]
    fn test_get_and_get_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        // Test first() / get()
        assert_eq!(view.first(), Some(&item0));
        assert_eq!(view.get(1), Some(&item1));
        assert_eq!(view.get(2), None); // out of bounds
        assert_eq!(view.get(100), None); // way out of bounds

        // Test get_mut() to modify an item
        let modified_item0 = TestStruct::new(111, 110);
        let item_ref = view.get_mut(0).unwrap();
        *item_ref = modified_item0;

        // Verify the modification
        assert_eq!(view.first(), Some(&modified_item0));
        assert_eq!(*view, [modified_item0, item1]);

        // Test get_mut() out of bounds
        assert_eq!(view.get_mut(2), None);
    }

    #[test]
    fn test_mutable_access_via_indexing() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        assert_eq!(view.len(), 2);

        // Modify via the mutable slice
        view[0].a = 99;

        let expected_item0 = TestStruct::new(99, 10);
        assert_eq!(view.first(), Some(&expected_item0));
        assert_eq!(*view, [expected_item0, item1]);
    }

    #[test]
    fn test_sort_by() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 5);

        let item0 = TestStruct::new(5, 1);
        let item1 = TestStruct::new(2, 2);
        let item2 = TestStruct::new(5, 3);
        let item3 = TestStruct::new(1, 4);
        let item4 = TestStruct::new(2, 5);

        view.push(item0).unwrap();
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        // Sort by `b` field in descending order.
        view.sort_by(|a, b| b.b.cmp(&a.b));
        let expected_order_by_b_desc = [
            item4, // b: 5
            item3, // b: 4
            item2, // b: 3
            item1, // b: 2
            item0, // b: 1
        ];
        assert_eq!(*view, expected_order_by_b_desc);

        // Now, sort by `a` in ascending order. A stable sort preserves the relative
        // order of equal elements from the previous state of the list.
        view.sort_by(|x, y| x.a.cmp(&y.a));

        let expected_order_by_a_stable = [
            item3, // a: 1
            item4, // a: 2 (was before item1 in the previous state)
            item1, // a: 2
            item2, // a: 5 (was before item0 in the previous state)
            item0, // a: 5
        ];
        assert_eq!(*view, expected_order_by_a_stable);
    }
}