vortex_array/
patches.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use arrow_buffer::BooleanBuffer;
9use itertools::Itertools as _;
10use num_traits::{NumCast, ToPrimitive};
11use serde::{Deserialize, Serialize};
12use vortex_buffer::BufferMut;
13use vortex_dtype::Nullability::NonNullable;
14use vortex_dtype::{
15    DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
16};
17use vortex_error::{
18    VortexError, VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err,
19};
20use vortex_mask::{AllOr, Mask};
21use vortex_scalar::{PValue, Scalar};
22use vortex_utils::aliases::hash_map::HashMap;
23
24use crate::arrays::PrimitiveArray;
25use crate::compute::{cast, filter, take};
26use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
27use crate::vtable::ValidityHelper;
28use crate::{Array, ArrayRef, IntoArray, ToCanonical};
29
/// Serializable summary of a set of patches, as produced by `Patches::to_metadata`.
///
/// It records how many patches exist, the offset applied to the stored indices and the
/// primitive type of the indices array. The patch indices and values themselves are not part
/// of this struct.
30#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
31pub struct PatchesMetadata {
    /// Number of patches (the length of the patch indices array).
32    #[prost(uint64, tag = "1")]
33    len: u64,
    /// Offset that is subtracted from every stored patch index.
34    #[prost(uint64, tag = "2")]
35    offset: u64,
    /// `PType` of the patch indices, stored as its `i32` enumeration value.
36    #[prost(enumeration = "PType", tag = "3")]
37    indices_ptype: i32,
38}
39
40impl PatchesMetadata {
41    pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
42        Self {
43            len: len as u64,
44            offset: offset as u64,
45            indices_ptype: indices_ptype as i32,
46        }
47    }
48
49    #[inline]
50    pub fn len(&self) -> usize {
51        usize::try_from(self.len).vortex_expect("len is a valid usize")
52    }
53
54    #[inline]
55    pub fn is_empty(&self) -> bool {
56        self.len == 0
57    }
58
59    #[inline]
60    pub fn offset(&self) -> usize {
61        usize::try_from(self.offset).vortex_expect("offset is a valid usize")
62    }
63
64    #[inline]
65    pub fn indices_dtype(&self) -> DType {
66        assert!(
67            self.indices_ptype().is_unsigned_int(),
68            "Patch indices must be unsigned integers"
69        );
70        DType::Primitive(self.indices_ptype(), NonNullable)
71    }
72}
73
74/// A helper for working with patched arrays.
///
/// A `Patches` value pairs an array of positions (`indices`) with an equally long array of
/// replacement `values`. The stored indices are shifted by `offset`: the position a patch
/// applies to is `indices[i] - offset`.
75#[derive(Debug, Clone)]
76pub struct Patches {
    /// Length of the array the patches apply to.
77    array_len: usize,
    /// Amount subtracted from each stored index to obtain the position in the patched array.
78    offset: usize,
    /// Patch positions (before subtracting `offset`); `Patches::new` requires a non-nullable
    /// unsigned integer array, and lookups binary-search it, so it must be sorted.
79    indices: ArrayRef,
    /// Patch values; `values[i]` belongs to `indices[i]`.
80    values: ArrayRef,
81}
82
83impl Patches {
84    pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
85        assert_eq!(
86            indices.len(),
87            values.len(),
88            "Patch indices and values must have the same length"
89        );
90        assert!(
91            indices.dtype().is_unsigned_int() && !indices.dtype().is_nullable(),
92            "Patch indices must be non-nullable unsigned integers, got {:?}",
93            indices.dtype()
94        );
95        assert!(
96            indices.len() <= array_len,
97            "Patch indices must be shorter than the array length"
98        );
99        assert!(!indices.is_empty(), "Patch indices must not be empty");
100        let max = usize::try_from(&indices.scalar_at(indices.len() - 1))
101            .vortex_expect("indices must be a number");
102        assert!(
103            max - offset < array_len,
104            "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
105        );
106
107        Self {
108            array_len,
109            offset,
110            indices,
111            values,
112        }
113    }
114
115    /// Construct new patches without validating any of the arguments
116    ///
117    /// # Safety
118    ///
119    /// Users have to assert that
120    /// * Indices and values have the same length
121    /// * Indices is an unsigned integer type
122    /// * Indices must be sorted
123    /// * Last value in indices is smaller than array_len
124    pub unsafe fn new_unchecked(
125        array_len: usize,
126        offset: usize,
127        indices: ArrayRef,
128        values: ArrayRef,
129    ) -> Self {
130        Self {
131            array_len,
132            offset,
133            indices,
134            values,
135        }
136    }
137
138    pub fn array_len(&self) -> usize {
139        self.array_len
140    }
141
142    pub fn num_patches(&self) -> usize {
143        self.indices.len()
144    }
145
146    pub fn dtype(&self) -> &DType {
147        self.values.dtype()
148    }
149
150    pub fn indices(&self) -> &ArrayRef {
151        &self.indices
152    }
153
154    pub fn into_indices(self) -> ArrayRef {
155        self.indices
156    }
157
158    pub fn indices_mut(&mut self) -> &mut ArrayRef {
159        &mut self.indices
160    }
161
162    pub fn values(&self) -> &ArrayRef {
163        &self.values
164    }
165
166    pub fn into_values(self) -> ArrayRef {
167        self.values
168    }
169
170    pub fn values_mut(&mut self) -> &mut ArrayRef {
171        &mut self.values
172    }
173
174    pub fn offset(&self) -> usize {
175        self.offset
176    }
177
178    pub fn indices_ptype(&self) -> PType {
179        PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
180    }
181
182    pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
183        if self.indices.len() > len {
184            vortex_bail!(
185                "Patch indices {} are longer than the array length {}",
186                self.indices.len(),
187                len
188            );
189        }
190        if self.values.dtype() != dtype {
191            vortex_bail!(
192                "Patch values dtype {} does not match array dtype {}",
193                self.values.dtype(),
194                dtype
195            );
196        }
197        Ok(PatchesMetadata {
198            len: self.indices.len() as u64,
199            offset: self.offset as u64,
200            indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
201                as i32,
202        })
203    }
204
205    pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
206        // SAFETY: casting does not affect the relationship between the indices and values
207        unsafe {
208            Ok(Self::new_unchecked(
209                self.array_len,
210                self.offset,
211                self.indices,
212                cast(&self.values, values_dtype)?,
213            ))
214        }
215    }
216
217    /// Get the patched value at a given index if it exists.
218    pub fn get_patched(&self, index: usize) -> Option<Scalar> {
219        self.search_index(index)
220            .to_found()
221            .map(|patch_idx| self.values().scalar_at(patch_idx))
222    }
223
224    /// Return the insertion point of `index` in the [Self::indices].
225    pub fn search_index(&self, index: usize) -> SearchResult {
226        self.indices.as_primitive_typed().search_sorted(
227            &PValue::U64((index + self.offset) as u64),
228            SearchSortedSide::Left,
229        )
230    }
231
232    /// Return the search_sorted result for the given target re-mapped into the original indices.
233    pub fn search_sorted<T: Into<Scalar>>(
234        &self,
235        target: T,
236        side: SearchSortedSide,
237    ) -> VortexResult<SearchResult> {
238        let target = target.into();
239
240        let sr = if self.values().dtype().is_primitive() {
241            self.values()
242                .as_primitive_typed()
243                .search_sorted(&target.as_primitive().pvalue(), side)
244        } else {
245            self.values().search_sorted(&target, side)
246        };
247
248        let index_idx = sr.to_offsets_index(self.indices().len(), side);
249        let index = usize::try_from(&self.indices().scalar_at(index_idx))? - self.offset;
250        Ok(match sr {
251            // If we reached the end of patched values when searching then the result is one after the last patch index
252            SearchResult::Found(i) => SearchResult::Found(
253                if i == self.indices().len() || side == SearchSortedSide::Right {
254                    index + 1
255                } else {
256                    index
257                },
258            ),
259            // If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
260            SearchResult::NotFound(i) => {
261                SearchResult::NotFound(if i == 0 { index } else { index + 1 })
262            }
263        })
264    }
265
266    /// Returns the minimum patch index
267    pub fn min_index(&self) -> usize {
268        let first = self
269            .indices
270            .scalar_at(0)
271            .as_primitive()
272            .as_::<usize>()
273            .vortex_expect("non-null");
274        first - self.offset
275    }
276
277    /// Returns the maximum patch index
278    pub fn max_index(&self) -> usize {
279        let last = self
280            .indices
281            .scalar_at(self.indices.len() - 1)
282            .as_primitive()
283            .as_::<usize>()
284            .vortex_expect("non-null");
285        last - self.offset
286    }
287
288    /// Filter the patches by a mask, resulting in new patches for the filtered array.
289    pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
290        if mask.len() != self.array_len {
291            vortex_bail!(
292                "Filter mask length {} does not match array length {}",
293                mask.len(),
294                self.array_len
295            );
296        }
297
298        match mask.indices() {
299            AllOr::All => Ok(Some(self.clone())),
300            AllOr::None => Ok(None),
301            AllOr::Some(mask_indices) => {
302                let flat_indices = self.indices().to_primitive()?;
303                match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
304                    filter_patches_with_mask(
305                        flat_indices.as_slice::<I>(),
306                        self.offset(),
307                        self.values(),
308                        mask_indices,
309                    )
310                })
311            }
312        }
313    }
314
315    /// Mask the patches, REMOVING the patches where the mask is true.
316    /// Unlike filter, this preserves the patch indices.
317    /// Unlike mask on a single array, this does not set masked values to null.
318    pub fn mask(&self, mask: &Mask) -> VortexResult<Option<Self>> {
319        if mask.len() != self.array_len {
320            vortex_bail!(
321                "Filter mask length {} does not match array length {}",
322                mask.len(),
323                self.array_len
324            );
325        }
326
327        let filter_mask = match mask.boolean_buffer() {
328            AllOr::All => return Ok(None),
329            AllOr::None => return Ok(Some(self.clone())),
330            AllOr::Some(masked) => {
331                let patch_indices = self.indices().to_primitive()?;
332                match_each_unsigned_integer_ptype!(patch_indices.ptype(), |P| {
333                    let patch_indices = patch_indices.as_slice::<P>();
334                    Mask::from_buffer(BooleanBuffer::collect_bool(patch_indices.len(), |i| {
335                        #[allow(clippy::cast_possible_truncation)]
336                        let idx = (patch_indices[i] as usize) - self.offset;
337                        !masked.value(idx)
338                    }))
339                })
340            }
341        };
342
343        if filter_mask.all_false() {
344            return Ok(None);
345        }
346
347        // SAFETY: filtering indices/values with same mask maintains their 1:1 relationship
348        unsafe {
349            Ok(Some(Self::new_unchecked(
350                self.array_len,
351                self.offset,
352                filter(&self.indices, &filter_mask)?,
353                filter(&self.values, &filter_mask)?,
354            )))
355        }
356    }
357
358    /// Slice the patches by a range of the patched array.
359    pub fn slice(&self, start: usize, stop: usize) -> Option<Self> {
360        let patch_start = self.search_index(start).to_index();
361        let patch_stop = self.search_index(stop).to_index();
362
363        if patch_start == patch_stop {
364            return None;
365        }
366
367        // Slice out the values and indices
368        let values = self.values().slice(patch_start, patch_stop);
369        let indices = self.indices().slice(patch_start, patch_stop);
370
371        Some(Self::new(
372            stop - start,
373            start + self.offset(),
374            indices,
375            values,
376        ))
377    }
378
379    // https://docs.google.com/spreadsheets/d/1D9vBZ1QJ6mwcIvV5wIL0hjGgVchcEnAyhvitqWu2ugU
380    const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
381
382    fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
383        (self.num_patches() as f64 / take_indices.len() as f64)
384            < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
385    }
386
387    /// Take the indices from the patches
388    ///
389    /// Any nulls in take_indices are added to the resulting patches.
390    pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
391        if take_indices.is_empty() {
392            return Ok(None);
393        }
394
395        let take_indices = take_indices.to_primitive()?;
396        if self.is_map_faster_than_search(&take_indices) {
397            self.take_map(take_indices, true)
398        } else {
399            self.take_search(take_indices, true)
400        }
401    }
402
403    /// Take the indices from the patches.
404    ///
405    /// Any nulls in take_indices are ignored.
406    pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
407        if take_indices.is_empty() {
408            return Ok(None);
409        }
410
411        let take_indices = take_indices.to_primitive()?;
412        if self.is_map_faster_than_search(&take_indices) {
413            self.take_map(take_indices, false)
414        } else {
415            self.take_search(take_indices, false)
416        }
417    }
418
419    pub fn take_search(
420        &self,
421        take_indices: PrimitiveArray,
422        include_nulls: bool,
423    ) -> VortexResult<Option<Self>> {
424        let indices = self.indices.to_primitive()?;
425        let new_length = take_indices.len();
426
427        let Some((new_indices, values_indices)) =
428            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
429                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
430                    take_search::<_, TakeIndices>(
431                        indices.as_slice::<Indices>(),
432                        take_indices,
433                        self.offset(),
434                        include_nulls,
435                    )?
436                })
437            })
438        else {
439            return Ok(None);
440        };
441
442        Ok(Some(Self::new(
443            new_length,
444            0,
445            new_indices,
446            take(self.values(), &values_indices)?,
447        )))
448    }
449
450    pub fn take_map(
451        &self,
452        take_indices: PrimitiveArray,
453        include_nulls: bool,
454    ) -> VortexResult<Option<Self>> {
455        let indices = self.indices.to_primitive()?;
456        let new_length = take_indices.len();
457
458        let Some((new_sparse_indices, value_indices)) =
459            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
460                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
461                    take_map::<_, TakeIndices>(
462                        indices.as_slice::<Indices>(),
463                        take_indices,
464                        self.offset(),
465                        self.min_index(),
466                        self.max_index(),
467                        include_nulls,
468                    )?
469                })
470            })
471        else {
472            return Ok(None);
473        };
474
475        Ok(Some(Patches::new(
476            new_length,
477            0,
478            new_sparse_indices,
479            take(self.values(), &value_indices)?,
480        )))
481    }
482
483    pub fn map_values<F>(self, f: F) -> VortexResult<Self>
484    where
485        F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
486    {
487        let values = f(self.values)?;
488        if self.indices.len() != values.len() {
489            vortex_bail!(
490                "map_values must preserve length: expected {} received {}",
491                self.indices.len(),
492                values.len()
493            )
494        }
495        Ok(Self::new(self.array_len, self.offset, self.indices, values))
496    }
497}
498
499fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
500    indices: &[I],
501    take_indices: PrimitiveArray,
502    indices_offset: usize,
503    include_nulls: bool,
504) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
505where
506    usize: TryFrom<T>,
507    VortexError: From<<usize as TryFrom<T>>::Error>,
508{
509    let take_indices_validity = take_indices.validity();
510    let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
511
512    let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
513        .as_slice::<T>()
514        .iter()
515        .enumerate()
516        .filter_map(|(i, &v)| {
517            I::from(v)
518                .and_then(|v| {
519                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
520                    if include_nulls && take_indices_validity.is_null(i).vortex_unwrap() {
521                        Some(0)
522                    } else {
523                        indices
524                            .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
525                            .to_found()
526                            .map(|patch_idx| patch_idx as u64)
527                    }
528                })
529                .map(|patch_idx| (patch_idx, i as u64))
530        })
531        .unzip();
532
533    if new_indices.is_empty() {
534        return Ok(None);
535    }
536
537    let new_indices = new_indices.into_array();
538    let values_validity = take_indices_validity.take(&new_indices)?;
539    Ok(Some((
540        new_indices,
541        PrimitiveArray::new(values_indices, values_validity).into_array(),
542    )))
543}
544
545fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
546    indices: &[I],
547    take_indices: PrimitiveArray,
548    indices_offset: usize,
549    min_index: usize,
550    max_index: usize,
551    include_nulls: bool,
552) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
553where
554    usize: TryFrom<T>,
555    VortexError: From<<I as TryFrom<usize>>::Error>,
556{
557    let take_indices_validity = take_indices.validity();
558    let take_indices = take_indices.as_slice::<T>();
559    let offset_i = I::try_from(indices_offset)?;
560
561    let sparse_index_to_value_index: HashMap<I, usize> = indices
562        .iter()
563        .copied()
564        .map(|idx| idx - offset_i)
565        .enumerate()
566        .map(|(value_index, sparse_index)| (sparse_index, value_index))
567        .collect();
568
569    let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
570        .iter()
571        .copied()
572        .map(usize::try_from)
573        .process_results(|iter| {
574            iter.enumerate()
575                .filter_map(|(idx_in_take, ti)| {
576                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
577                    if include_nulls && take_indices_validity.is_null(idx_in_take).vortex_unwrap() {
578                        Some((idx_in_take as u64, 0))
579                    } else if ti < min_index || ti > max_index {
580                        None
581                    } else {
582                        sparse_index_to_value_index
583                            .get(
584                                &I::try_from(ti)
585                                    .vortex_expect("take index is between min and max index"),
586                            )
587                            .map(|value_index| (idx_in_take as u64, *value_index as u64))
588                    }
589                })
590                .unzip()
591        })
592        .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
593
594    if new_sparse_indices.is_empty() {
595        return Ok(None);
596    }
597
598    let new_sparse_indices = new_sparse_indices.into_array();
599    let values_validity = take_indices_validity.take(&new_sparse_indices)?;
600    Ok(Some((
601        new_sparse_indices,
602        PrimitiveArray::new(value_indices, values_validity).into_array(),
603    )))
604}
605
606/// Filter patches with the provided mask (in flattened space).
607///
608/// The filter mask may contain indices that are non-patched. The return value of this function
609/// is a new set of `Patches` with the indices relative to the provided `mask` rank, and the
610/// patch values.
611fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
612    patch_indices: &[T],
613    offset: usize,
614    patch_values: &dyn Array,
615    mask_indices: &[usize],
616) -> VortexResult<Option<Patches>> {
617    let true_count = mask_indices.len();
618    let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
619    let mut new_mask_indices = Vec::with_capacity(true_count);
620
621    // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that
622    // the patches are relatively sparse compared to the overall mask, and so many indices in the
623    // mask will end up being skipped.
624    const STRIDE: usize = 4;
625
626    let mut mask_idx = 0usize;
627    let mut true_idx = 0usize;
628
629    while mask_idx < patch_indices.len() && true_idx < true_count {
630        // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices`
631        //  and `mask_indices`. We assume that Patches are sparse relative to the global space of
632        //  the mask (which covers both patch and non-patch values of the parent array), and so to
633        //  quickly jump through regions with no overlap, we attempt to move our pointers by STRIDE
634        //  elements on each iteration. If we cannot rule out overlap due to min/max values, we
635        //  fallback to performing a two-way iterator merge.
636        if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
637            // Load a vector of each into our registers.
638            let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
639            let left_max = patch_indices[mask_idx + STRIDE]
640                .to_usize()
641                .vortex_expect("left_max")
642                - offset;
643            let right_min = mask_indices[true_idx];
644            let right_max = mask_indices[true_idx + STRIDE];
645
646            if left_min > right_max {
647                // Advance right side
648                true_idx += STRIDE;
649                continue;
650            } else if right_min > left_max {
651                mask_idx += STRIDE;
652                continue;
653            } else {
654                // Fallthrough to direct comparison path.
655            }
656        }
657
658        // Two-way sorted iterator merge:
659
660        let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
661        let right = mask_indices[true_idx];
662
663        match left.cmp(&right) {
664            Ordering::Less => {
665                mask_idx += 1;
666            }
667            Ordering::Greater => {
668                true_idx += 1;
669            }
670            Ordering::Equal => {
671                // Save the mask index as well as the positional index.
672                new_mask_indices.push(mask_idx);
673                new_patch_indices.push(true_idx as u64);
674
675                mask_idx += 1;
676                true_idx += 1;
677            }
678        }
679    }
680
681    if new_mask_indices.is_empty() {
682        return Ok(None);
683    }
684
685    let new_patch_indices = new_patch_indices.into_array();
686    let new_patch_values = filter(
687        patch_values,
688        &Mask::from_indices(patch_values.len(), new_mask_indices),
689    )?;
690
691    Ok(Some(Patches::new(
692        true_count,
693        0,
694        new_patch_indices,
695        new_patch_values,
696    )))
697}
698
699#[cfg(test)]
700mod test {
701    use rstest::{fixture, rstest};
702    use vortex_buffer::buffer;
703    use vortex_mask::Mask;
704
705    use crate::arrays::PrimitiveArray;
706    use crate::patches::Patches;
707    use crate::search_sorted::{SearchResult, SearchSortedSide};
708    use crate::validity::Validity;
709    use crate::{IntoArray, ToCanonical};
710
711    #[test]
712    fn test_filter() {
713        let patches = Patches::new(
714            100,
715            0,
716            buffer![10u32, 11, 20].into_array(),
717            buffer![100, 110, 200].into_array(),
718        );
719
720        let filtered = patches
721            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
722            .unwrap()
723            .unwrap();
724
725        let indices = filtered.indices().to_primitive().unwrap();
726        let values = filtered.values().to_primitive().unwrap();
727        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
728        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
729    }
730
731    #[fixture]
732    fn patches() -> Patches {
733        Patches::new(
734            20,
735            0,
736            buffer![2u64, 9, 15].into_array(),
737            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
738        )
739    }
740
741    #[rstest]
742    fn search_larger_than(patches: Patches) {
743        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
744        assert_eq!(res, SearchResult::NotFound(16));
745    }
746
747    #[rstest]
748    fn search_less_than(patches: Patches) {
749        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
750        assert_eq!(res, SearchResult::NotFound(2));
751    }
752
753    #[rstest]
754    fn search_found(patches: Patches) {
755        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
756        assert_eq!(res, SearchResult::Found(9));
757    }
758
759    #[rstest]
760    fn search_not_found_right(patches: Patches) {
761        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
762        assert_eq!(res, SearchResult::NotFound(16));
763    }
764
765    #[rstest]
766    fn search_sliced(patches: Patches) {
767        let sliced = patches.slice(7, 20).unwrap();
768        assert_eq!(
769            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
770            SearchResult::NotFound(2)
771        );
772    }
773
774    #[test]
775    fn search_right() {
776        let patches = Patches::new(
777            6,
778            0,
779            buffer![0u8, 1, 4, 5].into_array(),
780            buffer![-128i8, -98, 8, 50].into_array(),
781        );
782
783        assert_eq!(
784            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
785            SearchResult::Found(2)
786        );
787        assert_eq!(
788            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
789            SearchResult::Found(6),
790        );
791        assert_eq!(
792            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
793            SearchResult::NotFound(2),
794        );
795        assert_eq!(
796            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
797            SearchResult::NotFound(6)
798        );
799    }
800
801    #[test]
802    fn search_left() {
803        let patches = Patches::new(
804            20,
805            0,
806            buffer![0u64, 1, 17, 18, 19].into_array(),
807            buffer![11i32, 22, 33, 44, 55].into_array(),
808        );
809        assert_eq!(
810            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
811            SearchResult::NotFound(2)
812        );
813        assert_eq!(
814            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
815            SearchResult::NotFound(19)
816        );
817    }
818
819    #[rstest]
820    fn take_with_nulls(patches: Patches) {
821        let taken = patches
822            .take(
823                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
824                    .into_array(),
825            )
826            .unwrap()
827            .unwrap();
828        let primitive_values = taken.values().to_primitive().unwrap();
829        assert_eq!(taken.array_len(), 2);
830        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
831        assert_eq!(
832            primitive_values.validity_mask().unwrap(),
833            Mask::from_iter(vec![true])
834        );
835    }
836
837    #[test]
838    fn test_slice() {
839        let values = buffer![15_u32, 135, 13531, 42].into_array();
840        let indices = buffer![10_u64, 11, 50, 100].into_array();
841
842        let patches = Patches::new(101, 0, indices, values);
843
844        let sliced = patches.slice(15, 100).unwrap();
845        assert_eq!(sliced.array_len(), 100 - 15);
846        let primitive = sliced.values().to_primitive().unwrap();
847
848        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
849    }
850
851    #[test]
852    fn doubly_sliced() {
853        let values = buffer![15_u32, 135, 13531, 42].into_array();
854        let indices = buffer![10_u64, 11, 50, 100].into_array();
855
856        let patches = Patches::new(101, 0, indices, values);
857
858        let sliced = patches.slice(15, 100).unwrap();
859        assert_eq!(sliced.array_len(), 100 - 15);
860        let primitive = sliced.values().to_primitive().unwrap();
861
862        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
863
864        let doubly_sliced = sliced.slice(35, 36).unwrap();
865        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();
866
867        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
868    }
869
870    #[test]
871    fn test_mask_all_true() {
872        let patches = Patches::new(
873            10,
874            0,
875            buffer![2u64, 5, 8].into_array(),
876            buffer![100i32, 200, 300].into_array(),
877        );
878
879        let mask = Mask::new_true(10);
880        let masked = patches.mask(&mask).unwrap();
881        assert!(masked.is_none());
882    }
883
884    #[test]
885    fn test_mask_all_false() {
886        let patches = Patches::new(
887            10,
888            0,
889            buffer![2u64, 5, 8].into_array(),
890            buffer![100i32, 200, 300].into_array(),
891        );
892
893        let mask = Mask::new_false(10);
894        let masked = patches.mask(&mask).unwrap().unwrap();
895
896        // No patch values should be masked
897        let masked_values = masked.values().to_primitive().unwrap();
898        assert_eq!(masked_values.as_slice::<i32>(), &[100, 200, 300]);
899        assert!(masked_values.is_valid(0).unwrap());
900        assert!(masked_values.is_valid(1).unwrap());
901        assert!(masked_values.is_valid(2).unwrap());
902
903        // Indices should remain unchanged
904        let indices = masked.indices().to_primitive().unwrap();
905        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
906    }
907
908    #[test]
909    fn test_mask_partial() {
910        let patches = Patches::new(
911            10,
912            0,
913            buffer![2u64, 5, 8].into_array(),
914            buffer![100i32, 200, 300].into_array(),
915        );
916
917        // Mask that removes patches at indices 2 and 8 (but not 5)
918        let mask = Mask::from_iter([
919            false, false, true, false, false, false, false, false, true, false,
920        ]);
921        let masked = patches.mask(&mask).unwrap().unwrap();
922
923        // Only the patch at index 5 should remain
924        let masked_values = masked.values().to_primitive().unwrap();
925        assert_eq!(masked_values.len(), 1);
926        assert_eq!(masked_values.as_slice::<i32>(), &[200]);
927
928        // Only index 5 should remain
929        let indices = masked.indices().to_primitive().unwrap();
930        assert_eq!(indices.as_slice::<u64>(), &[5]);
931    }
932
933    #[test]
934    fn test_mask_with_offset() {
935        let patches = Patches::new(
936            10,
937            5,                                  // offset
938            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
939            buffer![100i32, 200, 300].into_array(),
940        );
941
942        // Mask that sets actual index 2 to null
943        let mask = Mask::from_iter([
944            false, false, true, false, false, false, false, false, false, false,
945        ]);
946
947        let masked = patches.mask(&mask).unwrap().unwrap();
948        assert_eq!(masked.array_len(), 10);
949        assert_eq!(masked.offset(), 5);
950        let indices = masked.indices().to_primitive().unwrap();
951        assert_eq!(indices.as_slice::<u64>(), &[10, 13]);
952        let masked_values = masked.values().to_primitive().unwrap();
953        assert_eq!(masked_values.as_slice::<i32>(), &[200, 300]);
954    }
955
956    #[test]
957    fn test_mask_nullable_values() {
958        let patches = Patches::new(
959            10,
960            0,
961            buffer![2u64, 5, 8].into_array(),
962            PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
963        );
964
965        // Test masking removes patch at index 2
966        let mask = Mask::from_iter([
967            false, false, true, false, false, false, false, false, false, false,
968        ]);
969        let masked = patches.mask(&mask).unwrap().unwrap();
970
971        // Patches at indices 5 and 8 should remain
972        let indices = masked.indices().to_primitive().unwrap();
973        assert_eq!(indices.as_slice::<u64>(), &[5, 8]);
974
975        // Values should be the null and 300
976        let masked_values = masked.values().to_primitive().unwrap();
977        assert_eq!(masked_values.len(), 2);
978        assert!(!masked_values.is_valid(0).unwrap()); // the null value at index 5
979        assert!(masked_values.is_valid(1).unwrap()); // the 300 value at index 8
980        assert_eq!(i32::try_from(&masked_values.scalar_at(1)).unwrap(), 300i32);
981    }
982
983    #[test]
984    fn test_filter_keep_all() {
985        let patches = Patches::new(
986            10,
987            0,
988            buffer![2u64, 5, 8].into_array(),
989            buffer![100i32, 200, 300].into_array(),
990        );
991
992        // Keep all indices (mask with indices 0-9)
993        let mask = Mask::from_indices(10, (0..10).collect());
994        let filtered = patches.filter(&mask).unwrap().unwrap();
995
996        let indices = filtered.indices().to_primitive().unwrap();
997        let values = filtered.values().to_primitive().unwrap();
998        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
999        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1000    }
1001
1002    #[test]
1003    fn test_filter_none() {
1004        let patches = Patches::new(
1005            10,
1006            0,
1007            buffer![2u64, 5, 8].into_array(),
1008            buffer![100i32, 200, 300].into_array(),
1009        );
1010
1011        // Filter out all (empty mask means keep nothing)
1012        let mask = Mask::from_indices(10, vec![]);
1013        let filtered = patches.filter(&mask).unwrap();
1014        assert!(filtered.is_none());
1015    }
1016
1017    #[test]
1018    fn test_filter_with_indices() {
1019        let patches = Patches::new(
1020            10,
1021            0,
1022            buffer![2u64, 5, 8].into_array(),
1023            buffer![100i32, 200, 300].into_array(),
1024        );
1025
1026        // Keep indices 2, 5, 9 (so patches at 2 and 5 remain)
1027        let mask = Mask::from_indices(10, vec![2, 5, 9]);
1028        let filtered = patches.filter(&mask).unwrap().unwrap();
1029
1030        let indices = filtered.indices().to_primitive().unwrap();
1031        let values = filtered.values().to_primitive().unwrap();
1032        assert_eq!(indices.as_slice::<u64>(), &[0, 1]); // Adjusted indices
1033        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
1034    }
1035
1036    #[test]
1037    fn test_slice_full_range() {
1038        let patches = Patches::new(
1039            10,
1040            0,
1041            buffer![2u64, 5, 8].into_array(),
1042            buffer![100i32, 200, 300].into_array(),
1043        );
1044
1045        let sliced = patches.slice(0, 10).unwrap();
1046
1047        let indices = sliced.indices().to_primitive().unwrap();
1048        let values = sliced.values().to_primitive().unwrap();
1049        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1050        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1051    }
1052
1053    #[test]
1054    fn test_slice_partial() {
1055        let patches = Patches::new(
1056            10,
1057            0,
1058            buffer![2u64, 5, 8].into_array(),
1059            buffer![100i32, 200, 300].into_array(),
1060        );
1061
1062        // Slice from 3 to 8 (includes patch at 5)
1063        let sliced = patches.slice(3, 8).unwrap();
1064
1065        let indices = sliced.indices().to_primitive().unwrap();
1066        let values = sliced.values().to_primitive().unwrap();
1067        assert_eq!(indices.as_slice::<u64>(), &[5]); // Index stays the same
1068        assert_eq!(values.as_slice::<i32>(), &[200]);
1069        assert_eq!(sliced.array_len(), 5); // 8 - 3 = 5
1070        assert_eq!(sliced.offset(), 3); // New offset
1071    }
1072
1073    #[test]
1074    fn test_slice_no_patches() {
1075        let patches = Patches::new(
1076            10,
1077            0,
1078            buffer![2u64, 5, 8].into_array(),
1079            buffer![100i32, 200, 300].into_array(),
1080        );
1081
1082        // Slice from 6 to 7 (no patches in this range)
1083        let sliced = patches.slice(6, 7);
1084        assert!(sliced.is_none());
1085    }
1086
1087    #[test]
1088    fn test_slice_with_offset() {
1089        let patches = Patches::new(
1090            10,
1091            5,                                  // offset
1092            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
1093            buffer![100i32, 200, 300].into_array(),
1094        );
1095
1096        // Slice from 3 to 8 (includes patch at actual index 5)
1097        let sliced = patches.slice(3, 8).unwrap();
1098
1099        let indices = sliced.indices().to_primitive().unwrap();
1100        let values = sliced.values().to_primitive().unwrap();
1101        assert_eq!(indices.as_slice::<u64>(), &[10]); // Index stays the same (offset + 5 = 10)
1102        assert_eq!(values.as_slice::<i32>(), &[200]);
1103        assert_eq!(sliced.offset(), 8); // New offset = 5 + 3
1104    }
1105
1106    #[test]
1107    fn test_patch_values() {
1108        let patches = Patches::new(
1109            10,
1110            0,
1111            buffer![2u64, 5, 8].into_array(),
1112            buffer![100i32, 200, 300].into_array(),
1113        );
1114
1115        let values = patches.values().to_primitive().unwrap();
1116        assert_eq!(i32::try_from(&values.scalar_at(0)).unwrap(), 100i32);
1117        assert_eq!(i32::try_from(&values.scalar_at(1)).unwrap(), 200i32);
1118        assert_eq!(i32::try_from(&values.scalar_at(2)).unwrap(), 300i32);
1119    }
1120
1121    #[test]
1122    fn test_indices_range() {
1123        let patches = Patches::new(
1124            10,
1125            0,
1126            buffer![2u64, 5, 8].into_array(),
1127            buffer![100i32, 200, 300].into_array(),
1128        );
1129
1130        assert_eq!(patches.min_index(), 2);
1131        assert_eq!(patches.max_index(), 8);
1132    }
1133
1134    #[test]
1135    fn test_search_index() {
1136        let patches = Patches::new(
1137            10,
1138            0,
1139            buffer![2u64, 5, 8].into_array(),
1140            buffer![100i32, 200, 300].into_array(),
1141        );
1142
1143        // Search for exact indices
1144        assert_eq!(patches.search_index(2), SearchResult::Found(0));
1145        assert_eq!(patches.search_index(5), SearchResult::Found(1));
1146        assert_eq!(patches.search_index(8), SearchResult::Found(2));
1147
1148        // Search for non-patch indices
1149        assert_eq!(patches.search_index(0), SearchResult::NotFound(0));
1150        assert_eq!(patches.search_index(3), SearchResult::NotFound(1));
1151        assert_eq!(patches.search_index(6), SearchResult::NotFound(2));
1152        assert_eq!(patches.search_index(9), SearchResult::NotFound(3));
1153    }
1154
1155    #[test]
1156    fn test_mask_boundary_patches() {
1157        // Test masking patches at array boundaries
1158        let patches = Patches::new(
1159            10,
1160            0,
1161            buffer![0u64, 9].into_array(),
1162            buffer![100i32, 200].into_array(),
1163        );
1164
1165        let mask = Mask::from_iter([
1166            true, false, false, false, false, false, false, false, false, false,
1167        ]);
1168        let masked = patches.mask(&mask).unwrap();
1169        assert!(masked.is_some());
1170        let masked = masked.unwrap();
1171        let indices = masked.indices().to_primitive().unwrap();
1172        assert_eq!(indices.as_slice::<u64>(), &[9]);
1173        let values = masked.values().to_primitive().unwrap();
1174        assert_eq!(values.as_slice::<i32>(), &[200]);
1175    }
1176
1177    #[test]
1178    fn test_mask_all_patches_removed() {
1179        // Test when all patches are masked out
1180        let patches = Patches::new(
1181            10,
1182            0,
1183            buffer![2u64, 5, 8].into_array(),
1184            buffer![100i32, 200, 300].into_array(),
1185        );
1186
1187        // Mask that removes all patches
1188        let mask = Mask::from_iter([
1189            false, false, true, false, false, true, false, false, true, false,
1190        ]);
1191        let masked = patches.mask(&mask).unwrap();
1192        assert!(masked.is_none());
1193    }
1194
1195    #[test]
1196    fn test_mask_no_patches_removed() {
1197        // Test when no patches are masked
1198        let patches = Patches::new(
1199            10,
1200            0,
1201            buffer![2u64, 5, 8].into_array(),
1202            buffer![100i32, 200, 300].into_array(),
1203        );
1204
1205        // Mask that doesn't affect any patches
1206        let mask = Mask::from_iter([
1207            true, false, false, true, false, false, true, false, false, true,
1208        ]);
1209        let masked = patches.mask(&mask).unwrap().unwrap();
1210
1211        let indices = masked.indices().to_primitive().unwrap();
1212        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1213        let values = masked.values().to_primitive().unwrap();
1214        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1215    }
1216
1217    #[test]
1218    fn test_mask_single_patch() {
1219        // Test with a single patch
1220        let patches = Patches::new(
1221            5,
1222            0,
1223            buffer![2u64].into_array(),
1224            buffer![42i32].into_array(),
1225        );
1226
1227        // Mask that removes the single patch
1228        let mask = Mask::from_iter([false, false, true, false, false]);
1229        let masked = patches.mask(&mask).unwrap();
1230        assert!(masked.is_none());
1231
1232        // Mask that keeps the single patch
1233        let mask = Mask::from_iter([true, false, false, true, false]);
1234        let masked = patches.mask(&mask).unwrap().unwrap();
1235        let indices = masked.indices().to_primitive().unwrap();
1236        assert_eq!(indices.as_slice::<u64>(), &[2]);
1237    }
1238
1239    #[test]
1240    fn test_mask_contiguous_patches() {
1241        // Test with contiguous patches
1242        let patches = Patches::new(
1243            10,
1244            0,
1245            buffer![3u64, 4, 5, 6].into_array(),
1246            buffer![100i32, 200, 300, 400].into_array(),
1247        );
1248
1249        // Mask that removes middle patches
1250        let mask = Mask::from_iter([
1251            false, false, false, false, true, true, false, false, false, false,
1252        ]);
1253        let masked = patches.mask(&mask).unwrap().unwrap();
1254
1255        let indices = masked.indices().to_primitive().unwrap();
1256        assert_eq!(indices.as_slice::<u64>(), &[3, 6]);
1257        let values = masked.values().to_primitive().unwrap();
1258        assert_eq!(values.as_slice::<i32>(), &[100, 400]);
1259    }
1260
1261    #[test]
1262    fn test_mask_with_large_offset() {
1263        // Test with a large offset that shifts all indices
1264        let patches = Patches::new(
1265            20,
1266            15,
1267            buffer![16u64, 17, 19].into_array(), // actual indices are 1, 2, 4
1268            buffer![100i32, 200, 300].into_array(),
1269        );
1270
1271        // Mask that removes the patch at actual index 2
1272        let mask = Mask::from_iter([
1273            false, false, true, false, false, false, false, false, false, false, false, false,
1274            false, false, false, false, false, false, false, false,
1275        ]);
1276        let masked = patches.mask(&mask).unwrap().unwrap();
1277
1278        let indices = masked.indices().to_primitive().unwrap();
1279        assert_eq!(indices.as_slice::<u64>(), &[16, 19]);
1280        let values = masked.values().to_primitive().unwrap();
1281        assert_eq!(values.as_slice::<i32>(), &[100, 300]);
1282    }
1283
1284    #[test]
1285    #[should_panic(expected = "Filter mask length 5 does not match array length 10")]
1286    fn test_mask_wrong_length() {
1287        let patches = Patches::new(
1288            10,
1289            0,
1290            buffer![2u64, 5, 8].into_array(),
1291            buffer![100i32, 200, 300].into_array(),
1292        );
1293
1294        // Mask with wrong length
1295        let mask = Mask::from_iter([false, false, true, false, false]);
1296        let _ = patches.mask(&mask).unwrap();
1297    }
1298}