slice_group_by/binary_group/binary_group_by.rs

1use crate::offset_from;
2use std::cmp::Ordering::{Greater, Less};
3use std::slice::{from_raw_parts, from_raw_parts_mut};
4use std::{fmt, marker};
5
/// Generates the shared inherent helpers and the iterator trait impls for a
/// binary-search group iterator.
///
/// * `$name`    – the iterator struct; the expansion reads its `ptr`, `end`
///   and `predicate` fields.
/// * `$elem`    – the item type yielded (`&'a [T]` or `&'a mut [T]`).
/// * `$mkslice` – the raw-parts constructor matching `$elem`
///   (`from_raw_parts` or `from_raw_parts_mut`).
///
/// The expansion names `offset_from`, `Less`, `Greater` and `$mkslice`
/// unqualified, so they must be in scope where the macro is invoked.
macro_rules! binary_group_by {
    (struct $name:ident, $elem:ty, $mkslice:ident) => {
        impl<'a, T: 'a, P> $name<'a, T, P> {
            /// Returns `true` once every element has been handed out in a group.
            #[inline]
            pub fn is_empty(&self) -> bool {
                self.end == self.ptr
            }

            /// Returns how many elements still lie between the front and back cursors.
            #[inline]
            pub fn remainder_len(&self) -> usize {
                // SAFETY: `ptr` and `end` come from the same slice; the iterator only
                // moves `ptr` forward and `end` backward and never past each other.
                unsafe { offset_from(self.end, self.ptr) }
            }
        }

        impl<'a, T: 'a, P> std::iter::Iterator for $name<'a, T, P>
        where
            P: FnMut(&T, &T) -> bool,
        {
            type Item = $elem;

            /// Yields the group that starts at the front cursor.
            ///
            /// Each element after the first is classified as `Less` when
            /// `predicate(first, elem)` holds and as `Greater` otherwise. The binary
            /// search therefore never reports `Equal`, and its `Err` index counts the
            /// following elements that belong with the first one.
            #[inline]
            fn next(&mut self) -> Option<Self::Item> {
                if self.is_empty() {
                    return None;
                }

                let remaining = self.remainder_len();
                // SAFETY: the remainder is non-empty, so `ptr` points at a live element.
                let head: &T = unsafe { &*self.ptr };
                // SAFETY: `ptr + 1 .. ptr + remaining` lies inside the remainder and
                // does not include `head`.
                let rest = unsafe { $mkslice(self.ptr.add(1), remaining - 1) };

                let matched = rest
                    .binary_search_by(|x: &T| {
                        if (self.predicate)(head, x) {
                            Less
                        } else {
                            Greater
                        }
                    })
                    .unwrap_err();
                let group_len = matched + 1;

                // SAFETY: `group_len <= remaining`, so the group and the new front
                // cursor both stay within the remainder.
                let group = unsafe { $mkslice(self.ptr, group_len) };
                self.ptr = unsafe { self.ptr.add(group_len) };

                Some(group)
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // A non-empty remainder yields at least one group and at most one
                // group per element.
                if self.is_empty() {
                    (0, Some(0))
                } else {
                    (1, Some(self.remainder_len()))
                }
            }

            fn last(mut self) -> Option<Self::Item> {
                // The final group is exactly what the back cursor yields next.
                std::iter::DoubleEndedIterator::next_back(&mut self)
            }
        }

        impl<'a, T: 'a, P> std::iter::DoubleEndedIterator for $name<'a, T, P>
        where
            P: FnMut(&T, &T) -> bool,
        {
            /// Yields the group that ends at the back cursor.
            ///
            /// Each element before the last is classified as `Greater` when
            /// `predicate(last, elem)` holds and as `Less` otherwise. The `Err` index
            /// of the binary search is therefore where the final group begins.
            #[inline]
            fn next_back(&mut self) -> Option<Self::Item> {
                if self.is_empty() {
                    return None;
                }

                let remaining = self.remainder_len();
                // SAFETY: the remainder is non-empty, so `end - 1` is a live element.
                let tail: &T = unsafe { &*self.end.sub(1) };
                // SAFETY: `ptr .. ptr + remaining - 1` is the remainder minus `tail`.
                let front = unsafe { $mkslice(self.ptr, remaining - 1) };

                // NOTE(review): the predicate is called as `(last, earlier)` here but as
                // `(first, later)` in `next`. This only matters for asymmetric
                // predicates — confirm the argument order is intended.
                let start = front
                    .binary_search_by(|x: &T| {
                        if (self.predicate)(tail, x) {
                            Greater
                        } else {
                            Less
                        }
                    })
                    .unwrap_err();
                let group_len = remaining - start;

                // SAFETY: `start + group_len == remaining`, so the group is a suffix of
                // the remainder, and `end` moves back to where that group starts.
                let group = unsafe { $mkslice(self.ptr.add(start), group_len) };
                self.end = unsafe { self.end.sub(group_len) };

                Some(group)
            }
        }

        impl<'a, T: 'a, P> std::iter::FusedIterator for $name<'a, T, P>
        where
            P: FnMut(&T, &T) -> bool,
        {
        }
    };
}
103
/// An iterator over non-overlapping groups of a slice, found by *binary search*.
///
/// Unlike a linear scan, the search does not compare neighbouring elements.
/// The predicate is called with the first (or last) element of the current
/// group and with elements that are not necessarily adjacent to it.
/// The predicate must therefore describe a relation consistent with the sort
/// order of the slice, so that the elements of each group are contiguous.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BinaryGroupBy<'a, T, P> {
    /// Front cursor: the first element that has not been yielded yet.
    ptr: *const T,
    /// Back cursor: one past the last element that has not been yielded yet.
    end: *const T,
    /// Decides whether two elements belong to the same group.
    predicate: P,
    /// Ties the raw pointers to the shared borrow `&'a [T]` they come from.
    _phantom: marker::PhantomData<&'a T>,
}
114
115impl<'a, T: 'a, P> BinaryGroupBy<'a, T, P> {
116    pub fn new(slice: &'a [T], predicate: P) -> Self {
117        BinaryGroupBy {
118            ptr: slice.as_ptr(),
119            end: unsafe { slice.as_ptr().add(slice.len()) },
120            predicate,
121            _phantom: marker::PhantomData,
122        }
123    }
124}
125
126impl<'a, T: 'a, P> BinaryGroupBy<'a, T, P> {
127    /// Returns the remainder of the original slice that is going to be
128    /// returned by the iterator.
129    pub fn remainder(&self) -> &[T] {
130        let len = self.remainder_len();
131        unsafe { from_raw_parts(self.ptr, len) }
132    }
133}
134
135impl<'a, T: 'a + fmt::Debug, P> fmt::Debug for BinaryGroupBy<'a, T, P> {
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        f.debug_struct("BinaryGroupBy")
138            .field("remainder", &self.remainder())
139            .finish()
140    }
141}
142
143binary_group_by! { struct BinaryGroupBy, &'a [T], from_raw_parts }
144
/// An iterator over non-overlapping *mutable* groups of a slice, found by
/// *binary search*.
///
/// Unlike a linear scan, the search does not compare neighbouring elements.
/// The predicate is called with the first (or last) element of the current
/// group and with elements that are not necessarily adjacent to it.
/// The predicate must therefore describe a relation consistent with the sort
/// order of the slice, so that the elements of each group are contiguous.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BinaryGroupByMut<'a, T, P> {
    /// Front cursor: the first element that has not been yielded yet.
    ptr: *mut T,
    /// Back cursor: one past the last element that has not been yielded yet.
    end: *mut T,
    /// Decides whether two elements belong to the same group.
    predicate: P,
    /// Ties the raw pointers to the exclusive borrow `&'a mut [T]` they come from.
    _phantom: marker::PhantomData<&'a mut T>,
}
156
157impl<'a, T: 'a, P> BinaryGroupByMut<'a, T, P>
158where
159    P: FnMut(&T, &T) -> bool,
160{
161    pub fn new(slice: &'a mut [T], predicate: P) -> Self {
162        let ptr = slice.as_mut_ptr();
163        let end = unsafe { ptr.add(slice.len()) };
164        BinaryGroupByMut {
165            ptr,
166            end,
167            predicate,
168            _phantom: marker::PhantomData,
169        }
170    }
171}
172
173impl<'a, T: 'a, P> BinaryGroupByMut<'a, T, P> {
174    /// Returns the remainder of the original slice that is going to be
175    /// returned by the iterator.
176    pub fn into_remainder(self) -> &'a mut [T] {
177        let len = self.remainder_len();
178        unsafe { from_raw_parts_mut(self.ptr, len) }
179    }
180}
181
182impl<'a, T: 'a + fmt::Debug, P> fmt::Debug for BinaryGroupByMut<'a, T, P> {
183    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
184        let len = self.remainder_len();
185        let remainder = unsafe { from_raw_parts(self.ptr, len) };
186
187        f.debug_struct("BinaryGroupByMut")
188            .field("remainder", &remainder)
189            .finish()
190    }
191}
192
193binary_group_by! { struct BinaryGroupByMut, &'a mut [T], from_raw_parts_mut }