// swift_check/arch/simd_scan.rs

1#![allow(clippy::let_and_return)] // the contracts require this and without the `verify` feature
2                                  // these bindings will cause warnings.
3
4use crate::arch::{self, byte_ptr, simd_ptr, Vector};
5
6cfg_verify!(
7    use crate::arch::is_aligned;
8    use mirai_annotations::{checked_precondition, checked_postcondition};
9);
10
mod end_ptr {
    cfg_verify!(use super::checked_postcondition;);
    use crate::arch::Ptr;

    /// An immutable representation of the `data`'s upper bound
    #[derive(Copy, Clone)]
    #[repr(transparent)]
    pub struct EndPtr(*const Ptr);

    impl EndPtr {
        /// Constructs the one-past-the-end pointer of `data`.
        ///
        /// # Safety
        ///
        /// The resulting pointer points one byte past `data` and must never be
        /// dereferenced; it exists solely as an upper bound for comparisons.
        #[inline(always)] #[must_use]
        pub const unsafe fn new(data: &[u8]) -> Self {
            Self ( (data.as_ptr().add(data.len())).cast() )
        }
    }

    impl EndPtr {
        // Read-only accessor; the contracts assert the inner pointer is returned unmodified.
        #[cfg_attr(feature = "verify", contracts::ensures(self.0 == ret))]
        #[cfg_attr(feature = "verify", contracts::ensures(self.0 == old(self.0)))]
        #[inline(always)] #[must_use]
        pub fn get(&self) -> *const Ptr {
            self.0
        }

        /// Checks that the underlying pointer has not changed. This is less about ensuring
        /// correctness of the program and more about ensuring that no future changes violate
        /// the immutability invariant via adjustments to the contracts & this type.
        #[cfg(feature = "verify")]
        pub unsafe fn check(&self, data: &[u8]) -> bool {
            super::byte_ptr(self.get()) == super::byte_ptr(Self::new(data).get())
        }
    }
}
44
/// Asserts (only when the `verify` feature is enabled) that `$end_ptr` still equals the
/// upper bound of `$data`; expands to nothing in normal builds.
macro_rules! check_end_ptr {
    ($end_ptr:expr, $data:expr) => {
        #[cfg(feature = "verify")]
        assert!($end_ptr.check($data))
    };
}
51
52use end_ptr::EndPtr;
53
/// Reinterprets `x` as unsigned; the contract only relates input and output for `x >= 0`.
///
/// # Safety
///
/// Callers are expected to pass only non-negative offsets.
#[cfg_attr(feature = "verify", contracts::ensures(x >= 0 -> x as usize == ret))]
#[inline(always)] #[must_use]
unsafe fn remove_sign(x: isize) -> usize {
    x as usize
}
59
/// Byte offset of `l` from `r` (`l - r`), guaranteed non-negative by the precondition.
///
/// # Safety
///
/// Both pointers must be derived from the same allocation and satisfy `l >= r`.
#[cfg_attr(feature = "verify", contracts::requires(l >= r))]
#[inline(always)] #[must_use]
unsafe fn offset_from(l: *const u8, r: *const u8) -> isize {
    let ret = l.offset_from(r);
    // `l` being greater than 'r' is a precondition, therefore the offset will always be positive.
    contract!(assumed_postcondition!(ret >= 0));
    ret
}
68
/// The unsigned byte distance from `r` up to `l` (requires `l >= r`).
#[cfg_attr(feature = "verify", contracts::requires(l >= r))]
#[cfg_attr(feature = "verify", contracts::ensures(ret == remove_sign(offset_from(l, r))))]
#[inline(always)] #[must_use]
unsafe fn distance(l: *const u8, r: *const u8) -> usize {
    remove_sign(offset_from(l, r))
}
75
/// `distance`, lifted to SIMD-width pointers: the byte distance between `l` and `r`.
#[cfg_attr(feature = "verify", contracts::requires(l >= r))]
#[cfg_attr(feature = "verify", contracts::requires(byte_ptr(l) >= byte_ptr(r)))]
#[cfg_attr(feature = "verify", contracts::ensures(ret == distance(byte_ptr(l), byte_ptr(r))))]
#[inline(always)] #[must_use]
unsafe fn simd_distance(l: *const arch::Ptr, r: *const arch::Ptr) -> usize {
    distance(byte_ptr(l), byte_ptr(r))
}
83
/// Rewinds `cur` so that a full `arch::WIDTH`-byte load ending exactly at `_end` is
/// possible; `dist` is the remaining byte distance from `cur` to `_end`.
///
/// # Safety
///
/// `dist` must not exceed `arch::WIDTH` and the rewound pointer must stay inside the
/// original allocation (holds for callers whose `data.len() >= arch::WIDTH`).
#[cfg_attr(feature = "verify", contracts::requires(dist <= arch::WIDTH))]
#[cfg_attr(feature = "verify", contracts::ensures(
    dist == distance(byte_ptr(_end.get()), cur) -> incr_ptr(simd_ptr(ret)) == _end.get(),
    "If `dist` is the byte offset of `_end` to `cur` then incr_ptr(simd_ptr(ret)) equates to \
    `_end`"
))]
#[inline(always)] #[must_use]
unsafe fn make_space(cur: *const u8, dist: usize, _end: EndPtr) -> *const u8 {
    cur.sub(arch::WIDTH - dist)
}
94
/// Advances a SIMD-width pointer by one step (`arch::STEP` elements, i.e. `arch::WIDTH`
/// bytes); alignment is preserved per the contracts.
#[cfg_attr(feature = "verify", contracts::ensures(ptr == decr_ptr(ret)))]
#[cfg_attr(feature = "verify", contracts::ensures(ptr.add(arch::STEP) == ret))]
#[cfg_attr(feature = "verify", contracts::ensures(byte_ptr(decr_ptr(ret)) == byte_ptr(ptr)))]
#[cfg_attr(feature = "verify", contracts::ensures(is_aligned(ptr) -> is_aligned(ret)))]
#[inline(always)] #[must_use]
unsafe fn incr_ptr(ptr: *const arch::Ptr) -> *const arch::Ptr {
    ptr.add(arch::STEP)
}
103
/// Retreats a SIMD-width pointer by one step; the exact inverse of [`incr_ptr`].
#[cfg_attr(feature = "verify", contracts::ensures(incr_ptr(ret) == ptr))]
#[cfg_attr(feature = "verify", contracts::ensures(is_aligned(ptr) -> is_aligned(ret)))]
#[inline(always)] #[must_use]
unsafe fn decr_ptr(ptr: *const arch::Ptr) -> *const arch::Ptr {
    ptr.sub(arch::STEP)
}
110
/// Performs the first (possibly unaligned) load and yields the next, aligned, load position.
///
/// Returns the loaded `Vector` plus a pointer aligned to `arch::WIDTH`: one full step past
/// `ptr` when `ptr` was already aligned, otherwise `ptr` rounded up to the next alignment
/// boundary — in which case the first and second loads overlap by the misalignment amount.
#[cfg_attr(feature = "verify", contracts::ensures(
    ptr.align_offset(arch::WIDTH) != 0
        -> ret.1.cast::<u8>().offset_from(ptr) as usize == ptr.align_offset(arch::WIDTH)
))]
#[cfg_attr(feature = "verify", contracts::ensures(
    ptr.align_offset(arch::WIDTH) == 0 -> ret.1 == incr_ptr(simd_ptr(ptr))
))]
#[inline(always)] #[must_use]
unsafe fn align_ptr_or_incr(ptr: *const u8) -> (Vector, *const arch::Ptr) {
    match ptr.align_offset(arch::WIDTH) {
        0 => {
            let simd_ptr = simd_ptr(ptr);
            // When the pointer is already aligned, increment it by `arch::STEP`
            (arch::load_aligned(simd_ptr), incr_ptr(simd_ptr))
        },
        offset => {
            // When the pointer is not aligned, adjust it to the next alignment boundary
            (arch::load_unchecked(simd_ptr(ptr)), simd_ptr(ptr.add(offset)))
        }
    }
}
132
/// Whether one more full SIMD-width load starting at `cur` fits before `end`.
/// By contradiction: when this returns `false` (and `end >= cur` holds) the remaining
/// distance is strictly less than `arch::WIDTH`.
#[cfg_attr(feature = "verify", contracts::requires(end >= cur))]
// NOTE(review): this `ensures` is a tautology (`is_aligned(cur) -> is_aligned(cur)`);
// it presumably meant to relate alignment of something else — confirm the intended contract.
#[cfg_attr(feature = "verify", contracts::ensures(is_aligned(cur) -> is_aligned(cur)))]
#[cfg_attr(feature = "verify", contracts::ensures(ret -> incr_ptr(cur) <= end))]
#[inline(always)] #[must_use]
unsafe fn can_proceed(cur: *const arch::Ptr, end: *const arch::Ptr) -> bool {
    cur <= decr_ptr(end)
}
140
mod sealed {
    use super::*;
    cfg_verify!(use contracts::invariant;);

    /// Initiate the scanning process
    ///
    /// # Returns
    ///
    /// 0. The first `Vector` associated with `data` which most likely will not be included
    ///    in the `AlignedIter`, so it must be operated on independently. This is to align the
    ///    pointer, enabling aligned loads for a performance enhancement.
    /// 1. The `AlignedIter` which will handle loading all data after the initial `Vector`
    ///    (`0`).
    #[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
    #[inline(always)] #[must_use]
    pub unsafe fn init_scan(data: &[u8]) -> (Vector, AlignedIter) {
        let (vector, aligned_ptr) = align_ptr_or_incr(data.as_ptr());
        (
            vector,
            AlignedIter::after_first(aligned_ptr, data)
        )
    }

    /// Iterator over the `arch::WIDTH`-aligned load positions of the scanned slice.
    pub struct AlignedIter {
        // The next aligned position to load from; advanced one step per `next`.
        cur: *const arch::Ptr,
        // `EndPtr` cannot be mutated so is it safe to expose.
        pub end: EndPtr,
    }

    /// A final (possibly overlapping) tail load: the vector and the pointer it was read
    /// from, which the contracts pin to exactly one step before `end`.
    pub type Remainder = (Vector, *const arch::Ptr);

    /// The outcome of a single iteration step.
    pub enum Pointer {
        // An aligned load together with the pointer it was taken from.
        Aligned((Vector, *const arch::Ptr)),
        // The scan is complete, optionally carrying an unaligned tail load.
        End(Option<Remainder>)
    }

    // Predicates used only inside the verification contracts below.
    #[cfg(feature = "verify")]
    impl Pointer {
        #[must_use]
        pub const fn is_aligned(&self) -> bool {
            matches!(self, Self::Aligned(_))
        }
        #[must_use]
        pub const fn is_end_with_remaining(&self) -> bool {
            matches!(self, Self::End(Some(_)))
        }
        // Extracts the tail pointer; only meaningful (and only called) in contracts that
        // have already established `is_end_with_remaining`.
        #[must_use]
        fn remaining_end_ptr(&self) -> *const arch::Ptr {
            let Self::End(Some((_, ptr))) = self else {
                unreachable!(
                    "`remaining_end_ptr` called when state was not `End` with `Some` remainder"
                );
            };

            *ptr
        }
    }

    #[cfg_attr(feature = "verify", invariant(incr_ptr(self.cur) <= self.end.get()))]
    impl AlignedIter {
        #[cfg_attr(feature = "verify", contracts::requires(
            is_aligned(aligned_ptr),
            "To create an `AlignedIter` the `cur` pointer must be aligned to the `arch::WIDTH`"
        ))]
        #[inline(always)] #[must_use]
        unsafe fn after_first(aligned_ptr: *const arch::Ptr, data: &[u8]) -> Self {
            Self { cur: aligned_ptr, end: EndPtr::new(data) }
        }

        // A read-only snapshot of the current (aligned) position.
        #[cfg_attr(feature = "verify", contracts::ensures(is_aligned(ret)))]
        #[cfg_attr(feature = "verify", contracts::ensures(incr_ptr(ret) <= self.end.get()))]
        #[inline(always)] #[must_use]
        pub unsafe fn snap(&self) -> *const arch::Ptr {
            self.cur
        }

        // Returns the current position and advances `cur` by one step.
        #[cfg_attr(feature = "verify", contracts::ensures(is_aligned(ret)))]
        #[cfg_attr(feature = "verify", contracts::ensures(incr_ptr(ret) <= self.end.get()))]
        #[inline(always)] #[must_use]
        pub unsafe fn snap_and_incr(&mut self) -> *const arch::Ptr {
            let ret = self.snap();
            self.cur = incr_ptr(ret);
            ret
        }
    }

    #[cfg_attr(feature = "verify", invariant(self.end.get() >= self.cur))]
    impl AlignedIter {
        #[cfg_attr(feature = "verify", contracts::ensures(
            ret.is_aligned() -> incr_ptr(self.cur) <= self.end.get()
        ))]
        #[cfg_attr(feature = "verify", contracts::ensures(
            ret.is_end_with_remaining()
                -> incr_ptr(ret.remaining_end_ptr()) == self.end.get()
        ))]
        #[inline(always)] #[must_use]
        pub unsafe fn next(&mut self) -> Pointer {
            if can_proceed(self.cur, self.end.get()) {
                Pointer::Aligned({
                    let ptr = self.snap_and_incr();
                    (arch::load_aligned(ptr), ptr)
                })
            } else {
                // As `can_proceed` failed and our invariant requires `end` to be greater than
                // `cur` we know `distance(byte_ptr(self.end), byte_ptr(self.cur))` is less than
                // `arch::WIDTH`
                Pointer::End(self.end())
            }
        }

        /// Handle the potential unaligned remaining bytes of `data`
        ///
        /// # Unchecked Precondition
        ///
        /// The [`AlignedIter::next`] must be mutable, unfortunately this complicates / makes the
        /// expression of the precondition infeasible. That is why this method **cannot be exposed**
        /// outside the `sealed` module.
        ///
        /// <br>
        ///
        /// This method is only used within [`AlignedIter::next`] where [`can_proceed`] fails.
        /// [`can_proceed`]'s postcondition states by contradiction that if false then the distance
        /// between `cur` and `end` is less than `arch::WIDTH`, this postcondition in conjunction
        /// with the invariant `end >= cur` guarantees that the distance between `end` and `cur` is
        /// less than `arch::WIDTH`, ensuring that the preconditions of [`make_space`] hold.
        ///
        /// # Returns
        ///
        /// * `Some(Remainder)` - Indicating that it was impossible to scan all of `data` with
        ///   aligned loads and that there was a remainder. The pointer in `Remainder` is guaranteed
        ///   to be exactly `arch::WIDTH` less than `end`, with the vector representing the final
        ///   `arch::WIDTH` bytes of `data`
        /// * `None` - There was no remainder and the scan can be considered completed.
        ///
        /// ### Note
        ///
        /// As this does not mutate the iterator it is safe to be called multiple times as long as
        /// the invariants of the encapsulated pointers are not violated. Though there is no reason
        /// to do this nor should it be done.
        #[cfg_attr(feature = "verify", contracts::ensures(
            ret.is_some() -> incr_ptr(ret.unwrap().1) == self.end.get()
        ))]
        #[inline(always)]
        unsafe fn end(&self) -> Option<Remainder> {
            match simd_distance(self.end.get(), self.cur) {
                0 => None,
                dist => {
                    contract!(checked_assume!(dist <= arch::WIDTH));
                    let ptr = simd_ptr(make_space(
                        byte_ptr(self.cur), dist, self.end
                    ));
                    Some((arch::load_unchecked(ptr), ptr))
                }
            }
        }
    }
}
298
/// Whether `len` denotes a successful find: values of `arch::WIDTH` or more mean the
/// operation that produced `len` did not match anything (see the callers' contracts).
#[cfg_attr(feature = "verify", contracts::ensures(ret -> len < arch::WIDTH as u32))]
#[cfg_attr(feature = "verify", contracts::ensures(!ret -> len >= arch::WIDTH as u32))]
#[inline(always)] #[must_use]
fn valid_len(len: u32) -> bool {
    len < arch::WIDTH as u32
}
305
/// Lossless widening of `x` to `usize`; the contracts pin the round-trip in both directions.
#[cfg_attr(feature = "verify", contracts::ensures(x == ret as u32))]
#[cfg_attr(feature = "verify", contracts::ensures(ret == x as usize))]
#[inline(always)] #[must_use]
fn u32_as_usize(x: u32) -> usize {
    x as usize
}
312
/// Combines an in-vector match offset `len` with the distance already scanned (`cur` minus
/// the start of `data`) into an index into `data`.
///
/// Post-condition: As long as the preconditions are respected the returned value will always be
/// less than `data.len()`
#[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
#[cfg_attr(feature = "verify", contracts::requires(
    valid_len(len),
    "The length must be below the SIMD register width, it being outside of this range denotes that \
     find operation did not succeed."
))]
#[cfg_attr(feature = "verify", contracts::requires(
    cur >= data.as_ptr(),
    "The `cur` pointer must not have moved backwards beyond the start of `data`"
))]
#[cfg_attr(feature = "verify", contracts::requires(
    u32_as_usize(len) < usize::MAX - distance(cur, data.as_ptr()),
    "The length + the distance from `cur` to `data` must not be able to overflow."
))]
#[cfg_attr(feature = "verify", contracts::requires(
    incr_ptr(simd_ptr(cur)) <= _end.get(),
    "The distance between `cur` and `data` must be less than the data's length subtracted by the \
     SIMD register width."
))]
#[inline(always)] #[must_use]
unsafe fn final_length(len: u32, cur: *const u8, data: &[u8], _end: EndPtr) -> usize {
    let ret = u32_as_usize(len).wrapping_add(distance(cur, data.as_ptr()));
    // Relevant Preconditions:
    //
    // P(1) `incr_ptr(simd_ptr(cur)) <= _end.get()`: Guarantees that the distance between `cur`
    //   and `data.as_ptr()` is less than `arch::WIDTH` as `EndPtr` is an immutable representation
    //   of `data`'s upper bound.
    // P(2) `valid_len(len)`: Guarantees that the `len` is less than `arch::WIDTH`
    //
    // Therefore the sum of `len` and `distance(cur, data.as_ptr)` is guaranteed less than the
    // data's length.
    contract!(assumed_postcondition!(ret < data.len()));
    ret
}
349
/// Runs `$do` when `$len` is a successful find (see [`valid_len`]), otherwise the optional
/// `$otherwise`; re-states `valid_len`'s postcondition for the verifier inside the branch.
macro_rules! valid_len_then {
    ($len:ident, $do:expr $(, $otherwise:expr)?) => {
        if valid_len($len) {
            // Re-emphasize postcondition of `valid_len`
            contract!(debug_checked_assume!(valid_len($len)));
            $do
        } $( else {
            $otherwise
        })?
    };
}
361
/// Scans `data` and returns the index of the first byte for which `cond`'s vector output
/// produces a match (via `MoveMask::trailing_zeros`), or `None` when nothing matched.
///
/// # Safety
///
/// `data.len()` must be at least `arch::WIDTH` (stated as a contract under `verify`).
#[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
#[cfg_attr(feature = "verify", contracts::ensures(ret.is_some() -> ret.unwrap() < data.len()))]
#[inline(always)]
pub unsafe fn search<F: Fn(Vector) -> Vector>(data: &[u8], cond: F) -> Option<usize> {
    let (vector, mut iter) = sealed::init_scan(data);

    // Check the initial (possibly unaligned) vector before entering the aligned loop; its
    // offset into `data` is zero so `len` is the final index directly.
    let len = arch::MoveMask::new(cond(vector)).trailing_zeros();
    if valid_len(len) { return Some(len as usize); }

    loop {
        match iter.next() {
            sealed::Pointer::Aligned((vector, ptr)) => {
                check_end_ptr!(iter.end, data);
                let len = arch::MoveMask::new(cond(vector)).trailing_zeros();
                valid_len_then!(
                    len,
                    break Some(final_length(len, byte_ptr(ptr), data, iter.end))
                );
            },
            sealed::Pointer::End(Some((vector, ptr))) => {
                check_end_ptr!(iter.end, data);
                let len = arch::MoveMask::new(cond(vector)).trailing_zeros();
                // Last chance: the tail load; either produce the index or finish with `None`.
                break valid_len_then!(
                    len,
                    Some(final_length(len, byte_ptr(ptr), data, iter.end)),
                    None
                );
            },
            sealed::Pointer::End(None) => {
                check_end_ptr!(iter.end, data);
                break None;
            }
        }
    }
}
397
398#[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
399#[inline(always)]
400pub unsafe fn for_all_ensure_ct<F: Fn(Vector) -> Vector>(data: &[u8], cond: F, res: &mut bool) {
401    let (vector, mut iter) = sealed::init_scan(data);
402    *res &= crate::ensure!(vector, cond);
403
404    loop {
405        match iter.next() {
406            sealed::Pointer::Aligned((vector, _)) => {
407                check_end_ptr!(iter.end, data);
408                *res &= crate::ensure!(vector, cond);
409            },
410            sealed::Pointer::End(Some((vector, _))) => {
411                check_end_ptr!(iter.end, data);
412                *res &= crate::ensure!(vector, cond);
413                break;
414            },
415            sealed::Pointer::End(None) => {
416                check_end_ptr!(iter.end, data);
417                break;
418            }
419        }
420    }
421}
422
423#[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
424#[inline(always)] #[must_use]
425pub unsafe fn for_all_ensure<F: Fn(Vector) -> Vector>(data: &[u8], cond: F) -> bool {
426    let (vector, mut iter) = sealed::init_scan(data);
427    if !crate::ensure!(vector, cond) { return false; }
428
429    loop {
430        match iter.next() {
431            sealed::Pointer::Aligned((vector, _)) => {
432                check_end_ptr!(iter.end, data);
433                if !crate::ensure!(vector, cond) { break false; }
434            },
435            sealed::Pointer::End(Some((vector, _))) => {
436                check_end_ptr!(iter.end, data);
437                break crate::ensure!(vector, cond);
438            },
439            sealed::Pointer::End(None) => {
440                check_end_ptr!(iter.end, data);
441                break true;
442            }
443        }
444    }
445}
446
/// Feeds every `Vector` of `data` through `req.check` and hands the accumulated
/// `Requirement` back to the caller once the scan completes.
#[cfg(feature = "require")]
#[cfg_attr(feature = "verify", contracts::requires(data.len() >= arch::WIDTH))]
#[inline(always)]
pub unsafe fn ensure_requirements<R: crate::require::Requirement>(data: &[u8], mut req: R) -> R {
    let (first, mut iter) = sealed::init_scan(data);
    req.check(first);

    loop {
        let step = iter.next();
        // Verify-only guard: the upper bound must never drift during iteration.
        check_end_ptr!(iter.end, data);
        match step {
            sealed::Pointer::Aligned((chunk, _)) => req.check(chunk),
            sealed::Pointer::End(Some((chunk, _))) => {
                req.check(chunk);
                break req;
            },
            sealed::Pointer::End(None) => break req,
        }
    }
}
471}