s2n_quic_core/
slice.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::ops::{Deref, DerefMut};
5
6pub mod deque;
7
8/// Copies vectored slices from one slice into another
9///
10/// The number of copied items is limited by the minimum of the lengths of each of the slices.
11///
12/// Returns the number of entries that were copied
13#[inline]
14pub fn vectored_copy<A, B, T>(from: &[A], to: &mut [B]) -> usize
15where
16    A: Deref<Target = [T]>,
17    B: Deref<Target = [T]> + DerefMut,
18    T: Copy,
19{
20    zip_chunks(from, to, |a, b| {
21        b.copy_from_slice(a);
22    })
23}
24
25/// Zips entries from one slice to another
26///
27/// The number of copied items is limited by the minimum of the lengths of each of the slices.
28///
29/// Returns the number of entries that were processed
30#[inline]
31pub fn zip<A, At, B, Bt, F>(from: &[A], to: &mut [B], mut on_item: F) -> usize
32where
33    A: Deref<Target = [At]>,
34    B: Deref<Target = [Bt]> + DerefMut,
35    F: FnMut(&At, &mut Bt),
36{
37    zip_chunks(from, to, |a, b| {
38        for (a, b) in a.iter().zip(b) {
39            on_item(a, b);
40        }
41    })
42}
43
/// Zips overlapping chunks from one set of vectored slices to another
///
/// `from` and `to` are each treated as one logical sequence that is split across several
/// chunks. The chunk boundaries of the two sides do not need to line up: `on_slice` is invoked
/// for every run of entries that lies inside a single chunk on both sides, and the two slices
/// passed to it always have the same length. That length is zero when one of the chunks
/// involved is empty.
///
/// The number of processed items is limited by the minimum of the total lengths of each side.
///
/// Every chunk that is visited is dereferenced exactly once and all slicing is bounds checked,
/// so a `Deref`/`DerefMut` implementation that returns a different slice on each call cannot
/// lead to an out-of-bounds access.
///
/// Returns the number of entries that were processed
#[inline]
pub fn zip_chunks<A, At, B, Bt, F>(from: &[A], to: &mut [B], mut on_slice: F) -> usize
where
    A: Deref<Target = [At]>,
    B: Deref<Target = [Bt]> + DerefMut,
    F: FnMut(&[At], &mut [Bt]),
{
    let mut from_chunks = from.iter();
    let mut to_chunks = to.iter_mut();

    // `src` and `dst` hold the not-yet-processed remainder of the current chunk on each side.
    // If either side has no chunks at all there is nothing to process.
    let (mut src, mut dst): (&[At], &mut [Bt]) = match (from_chunks.next(), to_chunks.next()) {
        (Some(first_from), Some(first_to)) => (&**first_from, &mut **first_to),
        _ => return 0,
    };

    let mut count = 0;

    loop {
        // `len` is the minimum of both remainders, so neither split below can panic
        let len = src.len().min(dst.len());

        let (src_head, src_tail) = src.split_at(len);
        // temporarily move the mutable remainder out so that its tail can be stored back
        let (dst_head, dst_tail) = core::mem::take(&mut dst).split_at_mut(len);

        on_slice(src_head, dst_head);
        count += len;

        src = src_tail;
        dst = dst_tail;

        // At least one remainder is empty at this point, so every iteration advances at least
        // one side to its next chunk or reaches the end.

        // check if the `from` chunk is done
        if src.is_empty() {
            match from_chunks.next() {
                Some(chunk) => src = &**chunk,
                None => break,
            }
        }

        // check if the `to` chunk is done
        if dst.is_empty() {
            match to_chunks.next() {
                Some(chunk) => dst = &mut **chunk,
                None => break,
            }
        }
    }

    count
}
132
/// Moves consecutive repeated elements to the end of a slice
///
/// Returns two sub-slices of the input: the first holds the elements with consecutive repeats
/// removed, in their original order; the second holds the removed duplicates in no specified
/// order.
///
/// # Note
///
/// Only adjacent equal elements are detected, so items must be sorted before performing this
/// function if all duplicates should be removed
#[inline]
pub fn partition_dedup<T>(slice: &mut [T]) -> (&mut [T], &mut [T])
where
    T: PartialEq,
{
    // TODO replace with
    // https://doc.rust-lang.org/std/primitive.slice.html#method.partition_dedup
    // when stable
    //
    // For now, this performs the same sequence of comparisons and swaps

    // zero or one element can't contain a duplicate
    if slice.len() < 2 {
        return (slice, &mut []);
    }

    // `kept` is the length of the deduplicated prefix; the first element is always part of it
    let mut kept = 1;

    // Invariant: `kept <= candidate`, so `kept - 1`, `kept` and `candidate` are all in bounds
    for candidate in 1..slice.len() {
        // compare against the last element of the deduplicated prefix
        if slice[candidate] != slice[kept - 1] {
            // a new value: move it directly behind the prefix unless it's already there
            if candidate != kept {
                slice.swap(candidate, kept);
            }
            kept += 1;
        }
    }

    slice.split_at_mut(kept)
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::InlineVec;
    use bolero::check;

    /// Asserts that the flattened contents of `a` and `b` agree over their common prefix
    fn assert_eq_slices<A, B, T>(a: &[A], b: &[B])
    where
        A: Deref<Target = [T]>,
        B: Deref<Target = [T]>,
        T: PartialEq + core::fmt::Debug,
    {
        let lhs = a.iter().flat_map(|chunk| chunk.iter());
        let rhs = b.iter().flat_map(|chunk| chunk.iter());

        // Only the overlapping prefix is compared.
        //
        // Note: `Iterator::eq` isn't used here, as the total lengths may be different
        for (left, right) in Iterator::zip(lhs, rhs) {
            assert_eq!(left, right);
        }
    }

    #[test]
    fn vectored_copy_test() {
        // 12 entries spread over unevenly sized chunks, including an empty one
        let from: [&[i32]; 5] = [&[0], &[1, 2, 3], &[4, 5, 6, 7], &[], &[8, 9, 10, 11]];

        // the destination is always the smaller side: at most 5 chunks of 2 entries each
        for chunk_count in 0..6 {
            let mut to = vec![vec![0; 2]; chunk_count];

            assert_eq!(vectored_copy(&from, &mut to), chunk_count * 2);
            assert_eq_slices(&from, &to);
        }
    }

    /// Upper bound on both the number of chunks and the chunk length in the fuzz test
    const LEN: usize = if cfg!(kani) { 2 } else { 32 };

    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
    #[cfg_attr(miri, ignore)] // This test is too expensive for miri to complete in a reasonable amount of time
    fn vectored_copy_fuzz_test() {
        check!()
            .with_type::<(
                InlineVec<InlineVec<u8, LEN>, LEN>,
                InlineVec<InlineVec<u8, LEN>, LEN>,
            )>()
            .cloned()
            .for_each(|(source, mut destination)| {
                vectored_copy(&source, &mut destination);
                assert_eq_slices(&source, &destination);
            })
    }
}