vortex_array/
patches.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use arrow_buffer::BooleanBuffer;
9use itertools::Itertools as _;
10use num_traits::{NumCast, ToPrimitive};
11use serde::{Deserialize, Serialize};
12use vortex_buffer::BufferMut;
13use vortex_dtype::Nullability::NonNullable;
14use vortex_dtype::{
15    DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
16};
17use vortex_error::{
18    VortexError, VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err,
19};
20use vortex_mask::{AllOr, Mask};
21use vortex_scalar::{PValue, Scalar};
22use vortex_utils::aliases::hash_map::HashMap;
23
24use crate::arrays::PrimitiveArray;
25use crate::compute::{cast, filter, take};
26use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
27use crate::vtable::ValidityHelper;
28use crate::{Array, ArrayRef, IntoArray, ToCanonical};
29
/// Serializable description of a [`Patches`] instance (without the index/value buffers).
///
/// Produced by [`Patches::to_metadata`].
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches (the length of the indices array).
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset that is subtracted from every stored patch index.
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// Primitive type of the patch indices, stored as the `PType` enum's `i32` value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
39
40impl PatchesMetadata {
41    pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
42        Self {
43            len: len as u64,
44            offset: offset as u64,
45            indices_ptype: indices_ptype as i32,
46        }
47    }
48
49    #[inline]
50    pub fn len(&self) -> usize {
51        usize::try_from(self.len).vortex_expect("len is a valid usize")
52    }
53
54    #[inline]
55    pub fn is_empty(&self) -> bool {
56        self.len == 0
57    }
58
59    #[inline]
60    pub fn offset(&self) -> usize {
61        usize::try_from(self.offset).vortex_expect("offset is a valid usize")
62    }
63
64    #[inline]
65    pub fn indices_dtype(&self) -> DType {
66        assert!(
67            self.indices_ptype().is_unsigned_int(),
68            "Patch indices must be unsigned integers"
69        );
70        DType::Primitive(self.indices_ptype(), NonNullable)
71    }
72}
73
/// A helper for working with patched arrays.
///
/// A patch is a `(index, value)` pair; `indices` and `values` are parallel arrays of the same
/// length. The logical position of a patch in the patched array is `indices[i] - offset`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount subtracted from each stored index to obtain its logical position.
    offset: usize,
    /// Non-nullable unsigned integer positions of the patches (before subtracting `offset`).
    indices: ArrayRef,
    /// Patch values, one per entry of `indices`.
    values: ArrayRef,
}
82
83impl Patches {
84    pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
85        assert_eq!(
86            indices.len(),
87            values.len(),
88            "Patch indices and values must have the same length"
89        );
90        assert!(
91            indices.dtype().is_unsigned_int() && !indices.dtype().is_nullable(),
92            "Patch indices must be non-nullable unsigned integers, got {:?}",
93            indices.dtype()
94        );
95        assert!(
96            indices.len() <= array_len,
97            "Patch indices must be shorter than the array length"
98        );
99        assert!(!indices.is_empty(), "Patch indices must not be empty");
100        let max = usize::try_from(
101            &indices
102                .scalar_at(indices.len() - 1)
103                .vortex_expect("indices are not empty"),
104        )
105        .vortex_expect("indices must be a number");
106        assert!(
107            max - offset < array_len,
108            "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
109        );
110        Self::new_unchecked(array_len, offset, indices, values)
111    }
112
    /// Construct new patches without validating any of the arguments
    ///
    /// # Safety
    ///
    /// Users have to assert that
    /// * Indices and values have the same length
    /// * Indices is a non-nullable unsigned integer type
    /// * Indices is not empty
    /// * Indices must be sorted
    /// * Last value in indices, minus `offset`, is smaller than array_len
    ///
    /// Note that this is not an `unsafe fn`; the list above is the same set of conditions that
    /// [`Patches::new`] asserts (sortedness excepted, which `new` does not check).
    pub fn new_unchecked(
        array_len: usize,
        offset: usize,
        indices: ArrayRef,
        values: ArrayRef,
    ) -> Self {
        Self {
            array_len,
            offset,
            indices,
            values,
        }
    }
135
    /// Length of the array these patches apply to.
    pub fn array_len(&self) -> usize {
        self.array_len
    }

    /// Number of patches, i.e. the length of the indices array.
    pub fn num_patches(&self) -> usize {
        self.indices.len()
    }

    /// The [`DType`] of the patch values.
    pub fn dtype(&self) -> &DType {
        self.values.dtype()
    }

    /// Borrow the stored patch indices (not adjusted by the offset).
    pub fn indices(&self) -> &ArrayRef {
        &self.indices
    }

    /// Consume the patches and return the stored indices.
    pub fn into_indices(self) -> ArrayRef {
        self.indices
    }

    /// Mutably borrow the stored indices. No validation is performed on what is written.
    pub fn indices_mut(&mut self) -> &mut ArrayRef {
        &mut self.indices
    }

    /// Borrow the patch values.
    pub fn values(&self) -> &ArrayRef {
        &self.values
    }

    /// Consume the patches and return the values.
    pub fn into_values(self) -> ArrayRef {
        self.values
    }

    /// Mutably borrow the patch values. No validation is performed on what is written.
    pub fn values_mut(&mut self) -> &mut ArrayRef {
        &mut self.values
    }

    /// Offset that is subtracted from stored indices to obtain logical positions.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Primitive type of the indices array.
    ///
    /// # Panics
    ///
    /// Panics if the indices dtype cannot be converted to a [`PType`].
    pub fn indices_ptype(&self) -> PType {
        PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
    }
179
180    pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
181        if self.indices.len() > len {
182            vortex_bail!(
183                "Patch indices {} are longer than the array length {}",
184                self.indices.len(),
185                len
186            );
187        }
188        if self.values.dtype() != dtype {
189            vortex_bail!(
190                "Patch values dtype {} does not match array dtype {}",
191                self.values.dtype(),
192                dtype
193            );
194        }
195        Ok(PatchesMetadata {
196            len: self.indices.len() as u64,
197            offset: self.offset as u64,
198            indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
199                as i32,
200        })
201    }
202
203    pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
204        Ok(Self::new_unchecked(
205            self.array_len,
206            self.offset,
207            self.indices,
208            cast(&self.values, values_dtype)?,
209        ))
210    }
211
212    /// Get the patched value at a given index if it exists.
213    pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
214        if let Some(patch_idx) = self.search_index(index)?.to_found() {
215            self.values().scalar_at(patch_idx).map(Some)
216        } else {
217            Ok(None)
218        }
219    }
220
221    /// Return the insertion point of `index` in the [Self::indices].
222    pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
223        Ok(self.indices.as_primitive_typed().search_sorted(
224            &PValue::U64((index + self.offset) as u64),
225            SearchSortedSide::Left,
226        ))
227    }
228
229    /// Return the search_sorted result for the given target re-mapped into the original indices.
230    pub fn search_sorted<T: Into<Scalar>>(
231        &self,
232        target: T,
233        side: SearchSortedSide,
234    ) -> VortexResult<SearchResult> {
235        let target = target.into();
236
237        let sr = if self.values().dtype().is_primitive() {
238            self.values()
239                .as_primitive_typed()
240                .search_sorted(&target.as_primitive().pvalue(), side)
241        } else {
242            self.values().search_sorted(&target, side)
243        };
244
245        let index_idx = sr.to_offsets_index(self.indices().len(), side);
246        let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
247        Ok(match sr {
248            // If we reached the end of patched values when searching then the result is one after the last patch index
249            SearchResult::Found(i) => SearchResult::Found(
250                if i == self.indices().len() || side == SearchSortedSide::Right {
251                    index + 1
252                } else {
253                    index
254                },
255            ),
256            // If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
257            SearchResult::NotFound(i) => {
258                SearchResult::NotFound(if i == 0 { index } else { index + 1 })
259            }
260        })
261    }
262
263    /// Returns the minimum patch index
264    pub fn min_index(&self) -> VortexResult<usize> {
265        Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
266    }
267
268    /// Returns the maximum patch index
269    pub fn max_index(&self) -> VortexResult<usize> {
270        Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
271    }
272
273    /// Filter the patches by a mask, resulting in new patches for the filtered array.
274    pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
275        if mask.len() != self.array_len {
276            vortex_bail!(
277                "Filter mask length {} does not match array length {}",
278                mask.len(),
279                self.array_len
280            );
281        }
282
283        match mask.indices() {
284            AllOr::All => Ok(Some(self.clone())),
285            AllOr::None => Ok(None),
286            AllOr::Some(mask_indices) => {
287                let flat_indices = self.indices().to_primitive()?;
288                match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
289                    filter_patches_with_mask(
290                        flat_indices.as_slice::<I>(),
291                        self.offset(),
292                        self.values(),
293                        mask_indices,
294                    )
295                })
296            }
297        }
298    }
299
300    /// Mask the patches, REMOVING the patches where the mask is true.
301    /// Unlike filter, this preserves the patch indices.
302    /// Unlike mask on a single array, this does not set masked values to null.
303    pub fn mask(&self, mask: &Mask) -> VortexResult<Option<Self>> {
304        if mask.len() != self.array_len {
305            vortex_bail!(
306                "Filter mask length {} does not match array length {}",
307                mask.len(),
308                self.array_len
309            );
310        }
311
312        let filter_mask = match mask.boolean_buffer() {
313            AllOr::All => return Ok(None),
314            AllOr::None => return Ok(Some(self.clone())),
315            AllOr::Some(masked) => {
316                let patch_indices = self.indices().to_primitive()?;
317                match_each_unsigned_integer_ptype!(patch_indices.ptype(), |P| {
318                    let patch_indices = patch_indices.as_slice::<P>();
319                    Mask::from_buffer(BooleanBuffer::collect_bool(patch_indices.len(), |i| {
320                        #[allow(clippy::cast_possible_truncation)]
321                        let idx = (patch_indices[i] as usize) - self.offset;
322                        !masked.value(idx)
323                    }))
324                })
325            }
326        };
327
328        if filter_mask.all_false() {
329            return Ok(None);
330        }
331
332        Ok(Some(Self::new_unchecked(
333            self.array_len,
334            self.offset,
335            filter(&self.indices, &filter_mask)?,
336            filter(&self.values, &filter_mask)?,
337        )))
338    }
339
340    /// Slice the patches by a range of the patched array.
341    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
342        let patch_start = self.search_index(start)?.to_index();
343        let patch_stop = self.search_index(stop)?.to_index();
344
345        if patch_start == patch_stop {
346            return Ok(None);
347        }
348
349        // Slice out the values and indices
350        let values = self.values().slice(patch_start, patch_stop)?;
351        let indices = self.indices().slice(patch_start, patch_stop)?;
352
353        Ok(Some(Self::new(
354            stop - start,
355            start + self.offset(),
356            indices,
357            values,
358        )))
359    }
360
361    // https://docs.google.com/spreadsheets/d/1D9vBZ1QJ6mwcIvV5wIL0hjGgVchcEnAyhvitqWu2ugU
362    const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
363
364    fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
365        (self.num_patches() as f64 / take_indices.len() as f64)
366            < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
367    }
368
369    /// Take the indices from the patches
370    ///
371    /// Any nulls in take_indices are added to the resulting patches.
372    pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
373        if take_indices.is_empty() {
374            return Ok(None);
375        }
376
377        let take_indices = take_indices.to_primitive()?;
378        if self.is_map_faster_than_search(&take_indices) {
379            self.take_map(take_indices, true)
380        } else {
381            self.take_search(take_indices, true)
382        }
383    }
384
385    /// Take the indices from the patches.
386    ///
387    /// Any nulls in take_indices are ignored.
388    pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
389        if take_indices.is_empty() {
390            return Ok(None);
391        }
392
393        let take_indices = take_indices.to_primitive()?;
394        if self.is_map_faster_than_search(&take_indices) {
395            self.take_map(take_indices, false)
396        } else {
397            self.take_search(take_indices, false)
398        }
399    }
400
401    pub fn take_search(
402        &self,
403        take_indices: PrimitiveArray,
404        include_nulls: bool,
405    ) -> VortexResult<Option<Self>> {
406        let indices = self.indices.to_primitive()?;
407        let new_length = take_indices.len();
408
409        let Some((new_indices, values_indices)) =
410            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
411                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
412                    take_search::<_, TakeIndices>(
413                        indices.as_slice::<Indices>(),
414                        take_indices,
415                        self.offset(),
416                        include_nulls,
417                    )?
418                })
419            })
420        else {
421            return Ok(None);
422        };
423
424        Ok(Some(Self::new(
425            new_length,
426            0,
427            new_indices,
428            take(self.values(), &values_indices)?,
429        )))
430    }
431
432    pub fn take_map(
433        &self,
434        take_indices: PrimitiveArray,
435        include_nulls: bool,
436    ) -> VortexResult<Option<Self>> {
437        let indices = self.indices.to_primitive()?;
438        let new_length = take_indices.len();
439
440        let Some((new_sparse_indices, value_indices)) =
441            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
442                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
443                    take_map::<_, TakeIndices>(
444                        indices.as_slice::<Indices>(),
445                        take_indices,
446                        self.offset(),
447                        self.min_index()?,
448                        self.max_index()?,
449                        include_nulls,
450                    )?
451                })
452            })
453        else {
454            return Ok(None);
455        };
456
457        Ok(Some(Patches::new(
458            new_length,
459            0,
460            new_sparse_indices,
461            take(self.values(), &value_indices)?,
462        )))
463    }
464
465    pub fn map_values<F>(self, f: F) -> VortexResult<Self>
466    where
467        F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
468    {
469        let values = f(self.values)?;
470        if self.indices.len() != values.len() {
471            vortex_bail!(
472                "map_values must preserve length: expected {} received {}",
473                self.indices.len(),
474                values.len()
475            )
476        }
477        Ok(Self::new(self.array_len, self.offset, self.indices, values))
478    }
479}
480
481fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
482    indices: &[I],
483    take_indices: PrimitiveArray,
484    indices_offset: usize,
485    include_nulls: bool,
486) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
487where
488    usize: TryFrom<T>,
489    VortexError: From<<usize as TryFrom<T>>::Error>,
490{
491    let take_indices_validity = take_indices.validity();
492    let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
493
494    let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
495        .as_slice::<T>()
496        .iter()
497        .enumerate()
498        .filter_map(|(i, &v)| {
499            I::from(v)
500                .and_then(|v| {
501                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
502                    if include_nulls && take_indices_validity.is_null(i).vortex_unwrap() {
503                        Some(0)
504                    } else {
505                        indices
506                            .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
507                            .to_found()
508                            .map(|patch_idx| patch_idx as u64)
509                    }
510                })
511                .map(|patch_idx| (patch_idx, i as u64))
512        })
513        .unzip();
514
515    if new_indices.is_empty() {
516        return Ok(None);
517    }
518
519    let new_indices = new_indices.into_array();
520    let values_validity = take_indices_validity.take(&new_indices)?;
521    Ok(Some((
522        new_indices,
523        PrimitiveArray::new(values_indices, values_validity).into_array(),
524    )))
525}
526
527fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
528    indices: &[I],
529    take_indices: PrimitiveArray,
530    indices_offset: usize,
531    min_index: usize,
532    max_index: usize,
533    include_nulls: bool,
534) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
535where
536    usize: TryFrom<T>,
537    VortexError: From<<I as TryFrom<usize>>::Error>,
538{
539    let take_indices_validity = take_indices.validity();
540    let take_indices = take_indices.as_slice::<T>();
541    let offset_i = I::try_from(indices_offset)?;
542
543    let sparse_index_to_value_index: HashMap<I, usize> = indices
544        .iter()
545        .copied()
546        .map(|idx| idx - offset_i)
547        .enumerate()
548        .map(|(value_index, sparse_index)| (sparse_index, value_index))
549        .collect();
550
551    let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
552        .iter()
553        .copied()
554        .map(usize::try_from)
555        .process_results(|iter| {
556            iter.enumerate()
557                .filter_map(|(idx_in_take, ti)| {
558                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
559                    if include_nulls && take_indices_validity.is_null(idx_in_take).vortex_unwrap() {
560                        Some((idx_in_take as u64, 0))
561                    } else if ti < min_index || ti > max_index {
562                        None
563                    } else {
564                        sparse_index_to_value_index
565                            .get(
566                                &I::try_from(ti)
567                                    .vortex_expect("take index is between min and max index"),
568                            )
569                            .map(|value_index| (idx_in_take as u64, *value_index as u64))
570                    }
571                })
572                .unzip()
573        })
574        .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
575
576    if new_sparse_indices.is_empty() {
577        return Ok(None);
578    }
579
580    let new_sparse_indices = new_sparse_indices.into_array();
581    let values_validity = take_indices_validity.take(&new_sparse_indices)?;
582    Ok(Some((
583        new_sparse_indices,
584        PrimitiveArray::new(value_indices, values_validity).into_array(),
585    )))
586}
587
588/// Filter patches with the provided mask (in flattened space).
589///
590/// The filter mask may contain indices that are non-patched. The return value of this function
591/// is a new set of `Patches` with the indices relative to the provided `mask` rank, and the
592/// patch values.
593fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
594    patch_indices: &[T],
595    offset: usize,
596    patch_values: &dyn Array,
597    mask_indices: &[usize],
598) -> VortexResult<Option<Patches>> {
599    let true_count = mask_indices.len();
600    let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
601    let mut new_mask_indices = Vec::with_capacity(true_count);
602
603    // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that
604    // the patches are relatively sparse compared to the overall mask, and so many indices in the
605    // mask will end up being skipped.
606    const STRIDE: usize = 4;
607
608    let mut mask_idx = 0usize;
609    let mut true_idx = 0usize;
610
611    while mask_idx < patch_indices.len() && true_idx < true_count {
612        // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices`
613        //  and `mask_indices`. We assume that Patches are sparse relative to the global space of
614        //  the mask (which covers both patch and non-patch values of the parent array), and so to
615        //  quickly jump through regions with no overlap, we attempt to move our pointers by STRIDE
616        //  elements on each iteration. If we cannot rule out overlap due to min/max values, we
617        //  fallback to performing a two-way iterator merge.
618        if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
619            // Load a vector of each into our registers.
620            let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
621            let left_max = patch_indices[mask_idx + STRIDE]
622                .to_usize()
623                .vortex_expect("left_max")
624                - offset;
625            let right_min = mask_indices[true_idx];
626            let right_max = mask_indices[true_idx + STRIDE];
627
628            if left_min > right_max {
629                // Advance right side
630                true_idx += STRIDE;
631                continue;
632            } else if right_min > left_max {
633                mask_idx += STRIDE;
634                continue;
635            } else {
636                // Fallthrough to direct comparison path.
637            }
638        }
639
640        // Two-way sorted iterator merge:
641
642        let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
643        let right = mask_indices[true_idx];
644
645        match left.cmp(&right) {
646            Ordering::Less => {
647                mask_idx += 1;
648            }
649            Ordering::Greater => {
650                true_idx += 1;
651            }
652            Ordering::Equal => {
653                // Save the mask index as well as the positional index.
654                new_mask_indices.push(mask_idx);
655                new_patch_indices.push(true_idx as u64);
656
657                mask_idx += 1;
658                true_idx += 1;
659            }
660        }
661    }
662
663    if new_mask_indices.is_empty() {
664        return Ok(None);
665    }
666
667    let new_patch_indices = new_patch_indices.into_array();
668    let new_patch_values = filter(
669        patch_values,
670        &Mask::from_indices(patch_values.len(), new_mask_indices),
671    )?;
672
673    Ok(Some(Patches::new(
674        true_count,
675        0,
676        new_patch_indices,
677        new_patch_values,
678    )))
679}
680
681#[cfg(test)]
682mod test {
683    use rstest::{fixture, rstest};
684    use vortex_buffer::buffer;
685    use vortex_mask::Mask;
686
687    use crate::arrays::PrimitiveArray;
688    use crate::patches::Patches;
689    use crate::search_sorted::{SearchResult, SearchSortedSide};
690    use crate::validity::Validity;
691    use crate::{IntoArray, ToCanonical};
692
693    #[test]
694    fn test_filter() {
695        let patches = Patches::new(
696            100,
697            0,
698            buffer![10u32, 11, 20].into_array(),
699            buffer![100, 110, 200].into_array(),
700        );
701
702        let filtered = patches
703            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
704            .unwrap()
705            .unwrap();
706
707        let indices = filtered.indices().to_primitive().unwrap();
708        let values = filtered.values().to_primitive().unwrap();
709        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
710        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
711    }
712
713    #[fixture]
714    fn patches() -> Patches {
715        Patches::new(
716            20,
717            0,
718            buffer![2u64, 9, 15].into_array(),
719            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
720        )
721    }
722
723    #[rstest]
724    fn search_larger_than(patches: Patches) {
725        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
726        assert_eq!(res, SearchResult::NotFound(16));
727    }
728
729    #[rstest]
730    fn search_less_than(patches: Patches) {
731        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
732        assert_eq!(res, SearchResult::NotFound(2));
733    }
734
735    #[rstest]
736    fn search_found(patches: Patches) {
737        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
738        assert_eq!(res, SearchResult::Found(9));
739    }
740
741    #[rstest]
742    fn search_not_found_right(patches: Patches) {
743        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
744        assert_eq!(res, SearchResult::NotFound(16));
745    }
746
747    #[rstest]
748    fn search_sliced(patches: Patches) {
749        let sliced = patches.slice(7, 20).unwrap().unwrap();
750        assert_eq!(
751            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
752            SearchResult::NotFound(2)
753        );
754    }
755
756    #[test]
757    fn search_right() {
758        let patches = Patches::new(
759            6,
760            0,
761            buffer![0u8, 1, 4, 5].into_array(),
762            buffer![-128i8, -98, 8, 50].into_array(),
763        );
764
765        assert_eq!(
766            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
767            SearchResult::Found(2)
768        );
769        assert_eq!(
770            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
771            SearchResult::Found(6),
772        );
773        assert_eq!(
774            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
775            SearchResult::NotFound(2),
776        );
777        assert_eq!(
778            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
779            SearchResult::NotFound(6)
780        );
781    }
782
783    #[test]
784    fn search_left() {
785        let patches = Patches::new(
786            20,
787            0,
788            buffer![0u64, 1, 17, 18, 19].into_array(),
789            buffer![11i32, 22, 33, 44, 55].into_array(),
790        );
791        assert_eq!(
792            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
793            SearchResult::NotFound(2)
794        );
795        assert_eq!(
796            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
797            SearchResult::NotFound(19)
798        );
799    }
800
801    #[rstest]
802    fn take_with_nulls(patches: Patches) {
803        let taken = patches
804            .take(
805                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
806                    .into_array(),
807            )
808            .unwrap()
809            .unwrap();
810        let primitive_values = taken.values().to_primitive().unwrap();
811        assert_eq!(taken.array_len(), 2);
812        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
813        assert_eq!(
814            primitive_values.validity_mask().unwrap(),
815            Mask::from_iter(vec![true])
816        );
817    }
818
819    #[test]
820    fn test_slice() {
821        let values = buffer![15_u32, 135, 13531, 42].into_array();
822        let indices = buffer![10_u64, 11, 50, 100].into_array();
823
824        let patches = Patches::new(101, 0, indices, values);
825
826        let sliced = patches.slice(15, 100).unwrap().unwrap();
827        assert_eq!(sliced.array_len(), 100 - 15);
828        let primitive = sliced.values().to_primitive().unwrap();
829
830        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
831    }
832
833    #[test]
834    fn doubly_sliced() {
835        let values = buffer![15_u32, 135, 13531, 42].into_array();
836        let indices = buffer![10_u64, 11, 50, 100].into_array();
837
838        let patches = Patches::new(101, 0, indices, values);
839
840        let sliced = patches.slice(15, 100).unwrap().unwrap();
841        assert_eq!(sliced.array_len(), 100 - 15);
842        let primitive = sliced.values().to_primitive().unwrap();
843
844        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
845
846        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
847        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();
848
849        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
850    }
851
852    #[test]
853    fn test_mask_all_true() {
854        let patches = Patches::new(
855            10,
856            0,
857            buffer![2u64, 5, 8].into_array(),
858            buffer![100i32, 200, 300].into_array(),
859        );
860
861        let mask = Mask::new_true(10);
862        let masked = patches.mask(&mask).unwrap();
863        assert!(masked.is_none());
864    }
865
866    #[test]
867    fn test_mask_all_false() {
868        let patches = Patches::new(
869            10,
870            0,
871            buffer![2u64, 5, 8].into_array(),
872            buffer![100i32, 200, 300].into_array(),
873        );
874
875        let mask = Mask::new_false(10);
876        let masked = patches.mask(&mask).unwrap().unwrap();
877
878        // No patch values should be masked
879        let masked_values = masked.values().to_primitive().unwrap();
880        assert_eq!(masked_values.as_slice::<i32>(), &[100, 200, 300]);
881        assert!(masked_values.is_valid(0).unwrap());
882        assert!(masked_values.is_valid(1).unwrap());
883        assert!(masked_values.is_valid(2).unwrap());
884
885        // Indices should remain unchanged
886        let indices = masked.indices().to_primitive().unwrap();
887        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
888    }
889
890    #[test]
891    fn test_mask_partial() {
892        let patches = Patches::new(
893            10,
894            0,
895            buffer![2u64, 5, 8].into_array(),
896            buffer![100i32, 200, 300].into_array(),
897        );
898
899        // Mask that removes patches at indices 2 and 8 (but not 5)
900        let mask = Mask::from_iter([
901            false, false, true, false, false, false, false, false, true, false,
902        ]);
903        let masked = patches.mask(&mask).unwrap().unwrap();
904
905        // Only the patch at index 5 should remain
906        let masked_values = masked.values().to_primitive().unwrap();
907        assert_eq!(masked_values.len(), 1);
908        assert_eq!(masked_values.as_slice::<i32>(), &[200]);
909
910        // Only index 5 should remain
911        let indices = masked.indices().to_primitive().unwrap();
912        assert_eq!(indices.as_slice::<u64>(), &[5]);
913    }
914
915    #[test]
916    fn test_mask_with_offset() {
917        let patches = Patches::new(
918            10,
919            5,                                  // offset
920            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
921            buffer![100i32, 200, 300].into_array(),
922        );
923
924        // Mask that sets actual index 2 to null
925        let mask = Mask::from_iter([
926            false, false, true, false, false, false, false, false, false, false,
927        ]);
928
929        let masked = patches.mask(&mask).unwrap().unwrap();
930        assert_eq!(masked.array_len(), 10);
931        assert_eq!(masked.offset(), 5);
932        let indices = masked.indices().to_primitive().unwrap();
933        assert_eq!(indices.as_slice::<u64>(), &[10, 13]);
934        let masked_values = masked.values().to_primitive().unwrap();
935        assert_eq!(masked_values.as_slice::<i32>(), &[200, 300]);
936    }
937
938    #[test]
939    fn test_mask_nullable_values() {
940        let patches = Patches::new(
941            10,
942            0,
943            buffer![2u64, 5, 8].into_array(),
944            PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
945        );
946
947        // Test masking removes patch at index 2
948        let mask = Mask::from_iter([
949            false, false, true, false, false, false, false, false, false, false,
950        ]);
951        let masked = patches.mask(&mask).unwrap().unwrap();
952
953        // Patches at indices 5 and 8 should remain
954        let indices = masked.indices().to_primitive().unwrap();
955        assert_eq!(indices.as_slice::<u64>(), &[5, 8]);
956
957        // Values should be the null and 300
958        let masked_values = masked.values().to_primitive().unwrap();
959        assert_eq!(masked_values.len(), 2);
960        assert!(!masked_values.is_valid(0).unwrap()); // the null value at index 5
961        assert!(masked_values.is_valid(1).unwrap()); // the 300 value at index 8
962        assert_eq!(
963            i32::try_from(&masked_values.scalar_at(1).unwrap()).unwrap(),
964            300i32
965        );
966    }
967
968    #[test]
969    fn test_filter_keep_all() {
970        let patches = Patches::new(
971            10,
972            0,
973            buffer![2u64, 5, 8].into_array(),
974            buffer![100i32, 200, 300].into_array(),
975        );
976
977        // Keep all indices (mask with indices 0-9)
978        let mask = Mask::from_indices(10, (0..10).collect());
979        let filtered = patches.filter(&mask).unwrap().unwrap();
980
981        let indices = filtered.indices().to_primitive().unwrap();
982        let values = filtered.values().to_primitive().unwrap();
983        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
984        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
985    }
986
987    #[test]
988    fn test_filter_none() {
989        let patches = Patches::new(
990            10,
991            0,
992            buffer![2u64, 5, 8].into_array(),
993            buffer![100i32, 200, 300].into_array(),
994        );
995
996        // Filter out all (empty mask means keep nothing)
997        let mask = Mask::from_indices(10, vec![]);
998        let filtered = patches.filter(&mask).unwrap();
999        assert!(filtered.is_none());
1000    }
1001
1002    #[test]
1003    fn test_filter_with_indices() {
1004        let patches = Patches::new(
1005            10,
1006            0,
1007            buffer![2u64, 5, 8].into_array(),
1008            buffer![100i32, 200, 300].into_array(),
1009        );
1010
1011        // Keep indices 2, 5, 9 (so patches at 2 and 5 remain)
1012        let mask = Mask::from_indices(10, vec![2, 5, 9]);
1013        let filtered = patches.filter(&mask).unwrap().unwrap();
1014
1015        let indices = filtered.indices().to_primitive().unwrap();
1016        let values = filtered.values().to_primitive().unwrap();
1017        assert_eq!(indices.as_slice::<u64>(), &[0, 1]); // Adjusted indices
1018        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
1019    }
1020
1021    #[test]
1022    fn test_slice_full_range() {
1023        let patches = Patches::new(
1024            10,
1025            0,
1026            buffer![2u64, 5, 8].into_array(),
1027            buffer![100i32, 200, 300].into_array(),
1028        );
1029
1030        let sliced = patches.slice(0, 10).unwrap().unwrap();
1031
1032        let indices = sliced.indices().to_primitive().unwrap();
1033        let values = sliced.values().to_primitive().unwrap();
1034        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1035        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1036    }
1037
1038    #[test]
1039    fn test_slice_partial() {
1040        let patches = Patches::new(
1041            10,
1042            0,
1043            buffer![2u64, 5, 8].into_array(),
1044            buffer![100i32, 200, 300].into_array(),
1045        );
1046
1047        // Slice from 3 to 8 (includes patch at 5)
1048        let sliced = patches.slice(3, 8).unwrap().unwrap();
1049
1050        let indices = sliced.indices().to_primitive().unwrap();
1051        let values = sliced.values().to_primitive().unwrap();
1052        assert_eq!(indices.as_slice::<u64>(), &[5]); // Index stays the same
1053        assert_eq!(values.as_slice::<i32>(), &[200]);
1054        assert_eq!(sliced.array_len(), 5); // 8 - 3 = 5
1055        assert_eq!(sliced.offset(), 3); // New offset
1056    }
1057
1058    #[test]
1059    fn test_slice_no_patches() {
1060        let patches = Patches::new(
1061            10,
1062            0,
1063            buffer![2u64, 5, 8].into_array(),
1064            buffer![100i32, 200, 300].into_array(),
1065        );
1066
1067        // Slice from 6 to 7 (no patches in this range)
1068        let sliced = patches.slice(6, 7).unwrap();
1069        assert!(sliced.is_none());
1070    }
1071
1072    #[test]
1073    fn test_slice_with_offset() {
1074        let patches = Patches::new(
1075            10,
1076            5,                                  // offset
1077            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
1078            buffer![100i32, 200, 300].into_array(),
1079        );
1080
1081        // Slice from 3 to 8 (includes patch at actual index 5)
1082        let sliced = patches.slice(3, 8).unwrap().unwrap();
1083
1084        let indices = sliced.indices().to_primitive().unwrap();
1085        let values = sliced.values().to_primitive().unwrap();
1086        assert_eq!(indices.as_slice::<u64>(), &[10]); // Index stays the same (offset + 5 = 10)
1087        assert_eq!(values.as_slice::<i32>(), &[200]);
1088        assert_eq!(sliced.offset(), 8); // New offset = 5 + 3
1089    }
1090
1091    #[test]
1092    fn test_patch_values() {
1093        let patches = Patches::new(
1094            10,
1095            0,
1096            buffer![2u64, 5, 8].into_array(),
1097            buffer![100i32, 200, 300].into_array(),
1098        );
1099
1100        let values = patches.values().to_primitive().unwrap();
1101        assert_eq!(
1102            i32::try_from(&values.scalar_at(0).unwrap()).unwrap(),
1103            100i32
1104        );
1105        assert_eq!(
1106            i32::try_from(&values.scalar_at(1).unwrap()).unwrap(),
1107            200i32
1108        );
1109        assert_eq!(
1110            i32::try_from(&values.scalar_at(2).unwrap()).unwrap(),
1111            300i32
1112        );
1113    }
1114
1115    #[test]
1116    fn test_indices_range() {
1117        let patches = Patches::new(
1118            10,
1119            0,
1120            buffer![2u64, 5, 8].into_array(),
1121            buffer![100i32, 200, 300].into_array(),
1122        );
1123
1124        assert_eq!(patches.min_index().unwrap(), 2);
1125        assert_eq!(patches.max_index().unwrap(), 8);
1126    }
1127
1128    #[test]
1129    fn test_search_index() {
1130        let patches = Patches::new(
1131            10,
1132            0,
1133            buffer![2u64, 5, 8].into_array(),
1134            buffer![100i32, 200, 300].into_array(),
1135        );
1136
1137        // Search for exact indices
1138        assert_eq!(patches.search_index(2).unwrap(), SearchResult::Found(0));
1139        assert_eq!(patches.search_index(5).unwrap(), SearchResult::Found(1));
1140        assert_eq!(patches.search_index(8).unwrap(), SearchResult::Found(2));
1141
1142        // Search for non-patch indices
1143        assert_eq!(patches.search_index(0).unwrap(), SearchResult::NotFound(0));
1144        assert_eq!(patches.search_index(3).unwrap(), SearchResult::NotFound(1));
1145        assert_eq!(patches.search_index(6).unwrap(), SearchResult::NotFound(2));
1146        assert_eq!(patches.search_index(9).unwrap(), SearchResult::NotFound(3));
1147    }
1148
1149    #[test]
1150    fn test_mask_boundary_patches() {
1151        // Test masking patches at array boundaries
1152        let patches = Patches::new(
1153            10,
1154            0,
1155            buffer![0u64, 9].into_array(),
1156            buffer![100i32, 200].into_array(),
1157        );
1158
1159        let mask = Mask::from_iter([
1160            true, false, false, false, false, false, false, false, false, false,
1161        ]);
1162        let masked = patches.mask(&mask).unwrap();
1163        assert!(masked.is_some());
1164        let masked = masked.unwrap();
1165        let indices = masked.indices().to_primitive().unwrap();
1166        assert_eq!(indices.as_slice::<u64>(), &[9]);
1167        let values = masked.values().to_primitive().unwrap();
1168        assert_eq!(values.as_slice::<i32>(), &[200]);
1169    }
1170
1171    #[test]
1172    fn test_mask_all_patches_removed() {
1173        // Test when all patches are masked out
1174        let patches = Patches::new(
1175            10,
1176            0,
1177            buffer![2u64, 5, 8].into_array(),
1178            buffer![100i32, 200, 300].into_array(),
1179        );
1180
1181        // Mask that removes all patches
1182        let mask = Mask::from_iter([
1183            false, false, true, false, false, true, false, false, true, false,
1184        ]);
1185        let masked = patches.mask(&mask).unwrap();
1186        assert!(masked.is_none());
1187    }
1188
1189    #[test]
1190    fn test_mask_no_patches_removed() {
1191        // Test when no patches are masked
1192        let patches = Patches::new(
1193            10,
1194            0,
1195            buffer![2u64, 5, 8].into_array(),
1196            buffer![100i32, 200, 300].into_array(),
1197        );
1198
1199        // Mask that doesn't affect any patches
1200        let mask = Mask::from_iter([
1201            true, false, false, true, false, false, true, false, false, true,
1202        ]);
1203        let masked = patches.mask(&mask).unwrap().unwrap();
1204
1205        let indices = masked.indices().to_primitive().unwrap();
1206        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1207        let values = masked.values().to_primitive().unwrap();
1208        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1209    }
1210
1211    #[test]
1212    fn test_mask_single_patch() {
1213        // Test with a single patch
1214        let patches = Patches::new(
1215            5,
1216            0,
1217            buffer![2u64].into_array(),
1218            buffer![42i32].into_array(),
1219        );
1220
1221        // Mask that removes the single patch
1222        let mask = Mask::from_iter([false, false, true, false, false]);
1223        let masked = patches.mask(&mask).unwrap();
1224        assert!(masked.is_none());
1225
1226        // Mask that keeps the single patch
1227        let mask = Mask::from_iter([true, false, false, true, false]);
1228        let masked = patches.mask(&mask).unwrap().unwrap();
1229        let indices = masked.indices().to_primitive().unwrap();
1230        assert_eq!(indices.as_slice::<u64>(), &[2]);
1231    }
1232
1233    #[test]
1234    fn test_mask_contiguous_patches() {
1235        // Test with contiguous patches
1236        let patches = Patches::new(
1237            10,
1238            0,
1239            buffer![3u64, 4, 5, 6].into_array(),
1240            buffer![100i32, 200, 300, 400].into_array(),
1241        );
1242
1243        // Mask that removes middle patches
1244        let mask = Mask::from_iter([
1245            false, false, false, false, true, true, false, false, false, false,
1246        ]);
1247        let masked = patches.mask(&mask).unwrap().unwrap();
1248
1249        let indices = masked.indices().to_primitive().unwrap();
1250        assert_eq!(indices.as_slice::<u64>(), &[3, 6]);
1251        let values = masked.values().to_primitive().unwrap();
1252        assert_eq!(values.as_slice::<i32>(), &[100, 400]);
1253    }
1254
1255    #[test]
1256    fn test_mask_with_large_offset() {
1257        // Test with a large offset that shifts all indices
1258        let patches = Patches::new(
1259            20,
1260            15,
1261            buffer![16u64, 17, 19].into_array(), // actual indices are 1, 2, 4
1262            buffer![100i32, 200, 300].into_array(),
1263        );
1264
1265        // Mask that removes the patch at actual index 2
1266        let mask = Mask::from_iter([
1267            false, false, true, false, false, false, false, false, false, false, false, false,
1268            false, false, false, false, false, false, false, false,
1269        ]);
1270        let masked = patches.mask(&mask).unwrap().unwrap();
1271
1272        let indices = masked.indices().to_primitive().unwrap();
1273        assert_eq!(indices.as_slice::<u64>(), &[16, 19]);
1274        let values = masked.values().to_primitive().unwrap();
1275        assert_eq!(values.as_slice::<i32>(), &[100, 300]);
1276    }
1277
1278    #[test]
1279    #[should_panic(expected = "Filter mask length 5 does not match array length 10")]
1280    fn test_mask_wrong_length() {
1281        let patches = Patches::new(
1282            10,
1283            0,
1284            buffer![2u64, 5, 8].into_array(),
1285            buffer![100i32, 200, 300].into_array(),
1286        );
1287
1288        // Mask with wrong length
1289        let mask = Mask::from_iter([false, false, true, false, false]);
1290        let _ = patches.mask(&mask).unwrap();
1291    }
1292}