vortex_array/
patches.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Debug;
6use std::hash::Hash;
7use std::ops::Range;
8
9use arrow_buffer::BooleanBuffer;
10use itertools::Itertools as _;
11use num_traits::{NumCast, ToPrimitive};
12use serde::{Deserialize, Serialize};
13use vortex_buffer::BufferMut;
14use vortex_dtype::Nullability::NonNullable;
15use vortex_dtype::{
16    DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
17};
18use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
19use vortex_mask::{AllOr, Mask};
20use vortex_scalar::{PValue, Scalar};
21use vortex_utils::aliases::hash_map::HashMap;
22
23use crate::arrays::PrimitiveArray;
24use crate::compute::{cast, filter, take};
25use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
26use crate::vtable::ValidityHelper;
27use crate::{Array, ArrayRef, IntoArray, ToCanonical};
28
/// Serializable description of a set of patches: how many patches there are, the offset
/// recorded for their indices, and the primitive type the indices are stored as.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches, i.e. the length of the patch indices array.
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset recorded for the patch indices.
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// The `PType` of the patch indices, stored as its `i32` enumeration value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
38
39impl PatchesMetadata {
40    #[inline]
41    pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
42        Self {
43            len: len as u64,
44            offset: offset as u64,
45            indices_ptype: indices_ptype as i32,
46        }
47    }
48
49    #[inline]
50    pub fn len(&self) -> usize {
51        usize::try_from(self.len).vortex_expect("len is a valid usize")
52    }
53
54    #[inline]
55    pub fn is_empty(&self) -> bool {
56        self.len == 0
57    }
58
59    #[inline]
60    pub fn offset(&self) -> usize {
61        usize::try_from(self.offset).vortex_expect("offset is a valid usize")
62    }
63
64    #[inline]
65    pub fn indices_dtype(&self) -> DType {
66        assert!(
67            self.indices_ptype().is_unsigned_int(),
68            "Patch indices must be unsigned integers"
69        );
70        DType::Primitive(self.indices_ptype(), NonNullable)
71    }
72}
73
/// A helper for working with patched arrays.
///
/// Pairs a set of patch `values` with the `indices` at which they apply. Stored indices are
/// shifted by `offset`: the position of a patch in the patched array is `index - offset`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount subtracted from each stored index to obtain its position in the patched array.
    offset: usize,
    /// Patch positions (before subtracting `offset`); `Patches::new` requires a non-nullable
    /// unsigned integer array.
    indices: ArrayRef,
    /// Patch values; `Patches::new` requires the same length as `indices`.
    values: ArrayRef,
}
82
83impl Patches {
84    pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
85        assert_eq!(
86            indices.len(),
87            values.len(),
88            "Patch indices and values must have the same length"
89        );
90        assert!(
91            indices.dtype().is_unsigned_int() && !indices.dtype().is_nullable(),
92            "Patch indices must be non-nullable unsigned integers, got {:?}",
93            indices.dtype()
94        );
95        assert!(
96            indices.len() <= array_len,
97            "Patch indices must be shorter than the array length"
98        );
99        assert!(!indices.is_empty(), "Patch indices must not be empty");
100        let max = usize::try_from(&indices.scalar_at(indices.len() - 1))
101            .vortex_expect("indices must be a number");
102        assert!(
103            max - offset < array_len,
104            "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
105        );
106
107        Self {
108            array_len,
109            offset,
110            indices,
111            values,
112        }
113    }
114
    /// Construct new patches without validating any of the arguments.
    ///
    /// # Safety
    ///
    /// Callers must guarantee the invariants that [`Patches::new`] would otherwise check or
    /// rely on:
    /// * Indices and values have the same length
    /// * Indices is a non-empty, non-nullable unsigned integer array
    /// * Indices must be sorted
    /// * The last value in indices, minus `offset`, is smaller than `array_len`
    pub unsafe fn new_unchecked(
        array_len: usize,
        offset: usize,
        indices: ArrayRef,
        values: ArrayRef,
    ) -> Self {
        Self {
            array_len,
            offset,
            indices,
            values,
        }
    }
137
    /// Length of the array these patches apply to.
    #[inline]
    pub fn array_len(&self) -> usize {
        self.array_len
    }
142
    /// Number of patches, i.e. the length of the indices array.
    #[inline]
    pub fn num_patches(&self) -> usize {
        self.indices.len()
    }
147
    /// The [`DType`] of the patch values.
    #[inline]
    pub fn dtype(&self) -> &DType {
        self.values.dtype()
    }
152
    /// Borrow the patch indices (still shifted by [`Self::offset`]).
    #[inline]
    pub fn indices(&self) -> &ArrayRef {
        &self.indices
    }
157
    /// Consume the patches, returning only the indices array.
    #[inline]
    pub fn into_indices(self) -> ArrayRef {
        self.indices
    }
162
    /// Mutably borrow the patch indices.
    ///
    /// No validation is performed here; any replacement written through this reference
    /// bypasses the checks done by [`Patches::new`].
    #[inline]
    pub fn indices_mut(&mut self) -> &mut ArrayRef {
        &mut self.indices
    }
167
    /// Borrow the patch values.
    #[inline]
    pub fn values(&self) -> &ArrayRef {
        &self.values
    }
172
    /// Consume the patches, returning only the values array.
    #[inline]
    pub fn into_values(self) -> ArrayRef {
        self.values
    }
177
    /// Mutably borrow the patch values.
    ///
    /// No validation is performed here; any replacement written through this reference
    /// bypasses the checks done by [`Patches::new`].
    #[inline]
    pub fn values_mut(&mut self) -> &mut ArrayRef {
        &mut self.values
    }
182
    /// The offset that is subtracted from stored indices to obtain positions in the patched
    /// array.
    #[inline]
    pub fn offset(&self) -> usize {
        self.offset
    }
187
    /// The [`PType`] of the patch indices.
    ///
    /// # Panics
    ///
    /// Panics if the indices dtype cannot be converted to a `PType`.
    #[inline]
    pub fn indices_ptype(&self) -> PType {
        PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
    }
192
193    pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
194        if self.indices.len() > len {
195            vortex_bail!(
196                "Patch indices {} are longer than the array length {}",
197                self.indices.len(),
198                len
199            );
200        }
201        if self.values.dtype() != dtype {
202            vortex_bail!(
203                "Patch values dtype {} does not match array dtype {}",
204                self.values.dtype(),
205                dtype
206            );
207        }
208        Ok(PatchesMetadata {
209            len: self.indices.len() as u64,
210            offset: self.offset as u64,
211            indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
212                as i32,
213        })
214    }
215
216    pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
217        // SAFETY: casting does not affect the relationship between the indices and values
218        unsafe {
219            Ok(Self::new_unchecked(
220                self.array_len,
221                self.offset,
222                self.indices,
223                cast(&self.values, values_dtype)?,
224            ))
225        }
226    }
227
228    /// Get the patched value at a given index if it exists.
229    pub fn get_patched(&self, index: usize) -> Option<Scalar> {
230        self.search_index(index)
231            .to_found()
232            .map(|patch_idx| self.values().scalar_at(patch_idx))
233    }
234
235    /// Return the insertion point of `index` in the [Self::indices].
236    pub fn search_index(&self, index: usize) -> SearchResult {
237        if self.indices.is_canonical() {
238            let primitive = self.indices.to_primitive();
239            match_each_integer_ptype!(primitive.ptype(), |T| {
240                let Ok(needle) = T::try_from(index + self.offset) else {
241                    // If the needle is not of type T, then it cannot possibly be in this array.
242                    //
243                    // The needle is a non-negative integer (a usize); therefore, it must be larger
244                    // than all values in this array.
245                    return SearchResult::NotFound(primitive.len());
246                };
247                return primitive
248                    .as_slice::<T>()
249                    .search_sorted(&needle, SearchSortedSide::Left);
250            });
251        }
252        self.indices.as_primitive_typed().search_sorted(
253            &PValue::U64((index + self.offset) as u64),
254            SearchSortedSide::Left,
255        )
256    }
257
258    /// Return the search_sorted result for the given target re-mapped into the original indices.
259    pub fn search_sorted<T: Into<Scalar>>(
260        &self,
261        target: T,
262        side: SearchSortedSide,
263    ) -> VortexResult<SearchResult> {
264        let target = target.into();
265
266        let sr = if self.values().dtype().is_primitive() {
267            self.values()
268                .as_primitive_typed()
269                .search_sorted(&target.as_primitive().pvalue(), side)
270        } else {
271            self.values().search_sorted(&target, side)
272        };
273
274        let index_idx = sr.to_offsets_index(self.indices().len(), side);
275        let index = usize::try_from(&self.indices().scalar_at(index_idx))? - self.offset;
276        Ok(match sr {
277            // If we reached the end of patched values when searching then the result is one after the last patch index
278            SearchResult::Found(i) => SearchResult::Found(
279                if i == self.indices().len() || side == SearchSortedSide::Right {
280                    index + 1
281                } else {
282                    index
283                },
284            ),
285            // If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
286            SearchResult::NotFound(i) => {
287                SearchResult::NotFound(if i == 0 { index } else { index + 1 })
288            }
289        })
290    }
291
292    /// Returns the minimum patch index
293    pub fn min_index(&self) -> usize {
294        let first = self
295            .indices
296            .scalar_at(0)
297            .as_primitive()
298            .as_::<usize>()
299            .vortex_expect("non-null");
300        first - self.offset
301    }
302
303    /// Returns the maximum patch index
304    pub fn max_index(&self) -> usize {
305        let last = self
306            .indices
307            .scalar_at(self.indices.len() - 1)
308            .as_primitive()
309            .as_::<usize>()
310            .vortex_expect("non-null");
311        last - self.offset
312    }
313
314    /// Filter the patches by a mask, resulting in new patches for the filtered array.
315    pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
316        if mask.len() != self.array_len {
317            vortex_bail!(
318                "Filter mask length {} does not match array length {}",
319                mask.len(),
320                self.array_len
321            );
322        }
323
324        match mask.indices() {
325            AllOr::All => Ok(Some(self.clone())),
326            AllOr::None => Ok(None),
327            AllOr::Some(mask_indices) => {
328                let flat_indices = self.indices().to_primitive();
329                match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
330                    filter_patches_with_mask(
331                        flat_indices.as_slice::<I>(),
332                        self.offset(),
333                        self.values(),
334                        mask_indices,
335                    )
336                })
337            }
338        }
339    }
340
341    /// Mask the patches, REMOVING the patches where the mask is true.
342    /// Unlike filter, this preserves the patch indices.
343    /// Unlike mask on a single array, this does not set masked values to null.
344    pub fn mask(&self, mask: &Mask) -> VortexResult<Option<Self>> {
345        if mask.len() != self.array_len {
346            vortex_bail!(
347                "Filter mask length {} does not match array length {}",
348                mask.len(),
349                self.array_len
350            );
351        }
352
353        let filter_mask = match mask.boolean_buffer() {
354            AllOr::All => return Ok(None),
355            AllOr::None => return Ok(Some(self.clone())),
356            AllOr::Some(masked) => {
357                let patch_indices = self.indices().to_primitive();
358                match_each_unsigned_integer_ptype!(patch_indices.ptype(), |P| {
359                    let patch_indices = patch_indices.as_slice::<P>();
360                    Mask::from_buffer(BooleanBuffer::collect_bool(patch_indices.len(), |i| {
361                        #[allow(clippy::cast_possible_truncation)]
362                        let idx = (patch_indices[i] as usize) - self.offset;
363                        !masked.value(idx)
364                    }))
365                })
366            }
367        };
368
369        if filter_mask.all_false() {
370            return Ok(None);
371        }
372
373        // SAFETY: filtering indices/values with same mask maintains their 1:1 relationship
374        unsafe {
375            Ok(Some(Self::new_unchecked(
376                self.array_len,
377                self.offset,
378                filter(&self.indices, &filter_mask)?,
379                filter(&self.values, &filter_mask)?,
380            )))
381        }
382    }
383
384    /// Slice the patches by a range of the patched array.
385    pub fn slice(&self, range: Range<usize>) -> Option<Self> {
386        if range.len() == 1 {
387            let patch_index = self.search_index(range.start).to_found()?;
388            let values = self.values.slice(patch_index..patch_index + 1);
389            let indices = self.indices.slice(patch_index..patch_index + 1);
390            return Some(Self::new(1, range.start + self.offset(), indices, values));
391        }
392
393        let patch_start = self.search_index(range.start).to_index();
394        let patch_stop = self.search_index(range.end).to_index();
395
396        if patch_start == patch_stop {
397            return None;
398        }
399
400        // Slice out the values and indices
401        let values = self.values().slice(patch_start..patch_stop);
402        let indices = self.indices().slice(patch_start..patch_stop);
403
404        Some(Self::new(
405            range.len(),
406            range.start + self.offset(),
407            indices,
408            values,
409        ))
410    }
411
412    // https://docs.google.com/spreadsheets/d/1D9vBZ1QJ6mwcIvV5wIL0hjGgVchcEnAyhvitqWu2ugU
413    const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
414
415    fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
416        (self.num_patches() as f64 / take_indices.len() as f64)
417            < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
418    }
419
420    /// Take the indices from the patches
421    ///
422    /// Any nulls in take_indices are added to the resulting patches.
423    pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
424        if take_indices.is_empty() {
425            return Ok(None);
426        }
427
428        let take_indices = take_indices.to_primitive();
429        if self.is_map_faster_than_search(&take_indices) {
430            self.take_map(take_indices, true)
431        } else {
432            self.take_search(take_indices, true)
433        }
434    }
435
436    /// Take the indices from the patches.
437    ///
438    /// Any nulls in take_indices are ignored.
439    pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
440        if take_indices.is_empty() {
441            return Ok(None);
442        }
443
444        let take_indices = take_indices.to_primitive();
445        if self.is_map_faster_than_search(&take_indices) {
446            self.take_map(take_indices, false)
447        } else {
448            self.take_search(take_indices, false)
449        }
450    }
451
452    pub fn take_search(
453        &self,
454        take_indices: PrimitiveArray,
455        include_nulls: bool,
456    ) -> VortexResult<Option<Self>> {
457        let indices = self.indices.to_primitive();
458        let new_length = take_indices.len();
459
460        let Some((new_indices, values_indices)) =
461            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
462                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
463                    take_search::<_, TakeIndices>(
464                        indices.as_slice::<Indices>(),
465                        take_indices,
466                        self.offset(),
467                        include_nulls,
468                    )?
469                })
470            })
471        else {
472            return Ok(None);
473        };
474
475        Ok(Some(Self::new(
476            new_length,
477            0,
478            new_indices,
479            take(self.values(), &values_indices)?,
480        )))
481    }
482
483    pub fn take_map(
484        &self,
485        take_indices: PrimitiveArray,
486        include_nulls: bool,
487    ) -> VortexResult<Option<Self>> {
488        let indices = self.indices.to_primitive();
489        let new_length = take_indices.len();
490
491        let Some((new_sparse_indices, value_indices)) =
492            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
493                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
494                    take_map::<_, TakeIndices>(
495                        indices.as_slice::<Indices>(),
496                        take_indices,
497                        self.offset(),
498                        self.min_index(),
499                        self.max_index(),
500                        include_nulls,
501                    )?
502                })
503            })
504        else {
505            return Ok(None);
506        };
507
508        Ok(Some(Patches::new(
509            new_length,
510            0,
511            new_sparse_indices,
512            take(self.values(), &value_indices)?,
513        )))
514    }
515
516    pub fn map_values<F>(self, f: F) -> VortexResult<Self>
517    where
518        F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
519    {
520        let values = f(self.values)?;
521        if self.indices.len() != values.len() {
522            vortex_bail!(
523                "map_values must preserve length: expected {} received {}",
524                self.indices.len(),
525                values.len()
526            )
527        }
528        Ok(Self::new(self.array_len, self.offset, self.indices, values))
529    }
530}
531
532fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
533    indices: &[I],
534    take_indices: PrimitiveArray,
535    indices_offset: usize,
536    include_nulls: bool,
537) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
538where
539    usize: TryFrom<T>,
540    VortexError: From<<usize as TryFrom<T>>::Error>,
541{
542    let take_indices_validity = take_indices.validity();
543    let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
544
545    let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
546        .as_slice::<T>()
547        .iter()
548        .enumerate()
549        .filter_map(|(i, &v)| {
550            I::from(v)
551                .and_then(|v| {
552                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
553                    if include_nulls && take_indices_validity.is_null(i) {
554                        Some(0)
555                    } else {
556                        indices
557                            .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
558                            .to_found()
559                            .map(|patch_idx| patch_idx as u64)
560                    }
561                })
562                .map(|patch_idx| (patch_idx, i as u64))
563        })
564        .unzip();
565
566    if new_indices.is_empty() {
567        return Ok(None);
568    }
569
570    let new_indices = new_indices.into_array();
571    let values_validity = take_indices_validity.take(&new_indices)?;
572    Ok(Some((
573        new_indices,
574        PrimitiveArray::new(values_indices, values_validity).into_array(),
575    )))
576}
577
578fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
579    indices: &[I],
580    take_indices: PrimitiveArray,
581    indices_offset: usize,
582    min_index: usize,
583    max_index: usize,
584    include_nulls: bool,
585) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
586where
587    usize: TryFrom<T>,
588    VortexError: From<<I as TryFrom<usize>>::Error>,
589{
590    let take_indices_validity = take_indices.validity();
591    let take_indices = take_indices.as_slice::<T>();
592    let offset_i = I::try_from(indices_offset)?;
593
594    let sparse_index_to_value_index: HashMap<I, usize> = indices
595        .iter()
596        .copied()
597        .map(|idx| idx - offset_i)
598        .enumerate()
599        .map(|(value_index, sparse_index)| (sparse_index, value_index))
600        .collect();
601
602    let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
603        .iter()
604        .copied()
605        .map(usize::try_from)
606        .process_results(|iter| {
607            iter.enumerate()
608                .filter_map(|(idx_in_take, ti)| {
609                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
610                    if include_nulls && take_indices_validity.is_null(idx_in_take) {
611                        Some((idx_in_take as u64, 0))
612                    } else if ti < min_index || ti > max_index {
613                        None
614                    } else {
615                        sparse_index_to_value_index
616                            .get(
617                                &I::try_from(ti)
618                                    .vortex_expect("take index is between min and max index"),
619                            )
620                            .map(|value_index| (idx_in_take as u64, *value_index as u64))
621                    }
622                })
623                .unzip()
624        })
625        .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
626
627    if new_sparse_indices.is_empty() {
628        return Ok(None);
629    }
630
631    let new_sparse_indices = new_sparse_indices.into_array();
632    let values_validity = take_indices_validity.take(&new_sparse_indices)?;
633    Ok(Some((
634        new_sparse_indices,
635        PrimitiveArray::new(value_indices, values_validity).into_array(),
636    )))
637}
638
639/// Filter patches with the provided mask (in flattened space).
640///
641/// The filter mask may contain indices that are non-patched. The return value of this function
642/// is a new set of `Patches` with the indices relative to the provided `mask` rank, and the
643/// patch values.
644fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
645    patch_indices: &[T],
646    offset: usize,
647    patch_values: &dyn Array,
648    mask_indices: &[usize],
649) -> VortexResult<Option<Patches>> {
650    let true_count = mask_indices.len();
651    let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
652    let mut new_mask_indices = Vec::with_capacity(true_count);
653
654    // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that
655    // the patches are relatively sparse compared to the overall mask, and so many indices in the
656    // mask will end up being skipped.
657    const STRIDE: usize = 4;
658
659    let mut mask_idx = 0usize;
660    let mut true_idx = 0usize;
661
662    while mask_idx < patch_indices.len() && true_idx < true_count {
663        // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices`
664        //  and `mask_indices`. We assume that Patches are sparse relative to the global space of
665        //  the mask (which covers both patch and non-patch values of the parent array), and so to
666        //  quickly jump through regions with no overlap, we attempt to move our pointers by STRIDE
667        //  elements on each iteration. If we cannot rule out overlap due to min/max values, we
668        //  fallback to performing a two-way iterator merge.
669        if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
670            // Load a vector of each into our registers.
671            let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
672            let left_max = patch_indices[mask_idx + STRIDE]
673                .to_usize()
674                .vortex_expect("left_max")
675                - offset;
676            let right_min = mask_indices[true_idx];
677            let right_max = mask_indices[true_idx + STRIDE];
678
679            if left_min > right_max {
680                // Advance right side
681                true_idx += STRIDE;
682                continue;
683            } else if right_min > left_max {
684                mask_idx += STRIDE;
685                continue;
686            } else {
687                // Fallthrough to direct comparison path.
688            }
689        }
690
691        // Two-way sorted iterator merge:
692
693        let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
694        let right = mask_indices[true_idx];
695
696        match left.cmp(&right) {
697            Ordering::Less => {
698                mask_idx += 1;
699            }
700            Ordering::Greater => {
701                true_idx += 1;
702            }
703            Ordering::Equal => {
704                // Save the mask index as well as the positional index.
705                new_mask_indices.push(mask_idx);
706                new_patch_indices.push(true_idx as u64);
707
708                mask_idx += 1;
709                true_idx += 1;
710            }
711        }
712    }
713
714    if new_mask_indices.is_empty() {
715        return Ok(None);
716    }
717
718    let new_patch_indices = new_patch_indices.into_array();
719    let new_patch_values = filter(
720        patch_values,
721        &Mask::from_indices(patch_values.len(), new_mask_indices),
722    )?;
723
724    Ok(Some(Patches::new(
725        true_count,
726        0,
727        new_patch_indices,
728        new_patch_values,
729    )))
730}
731
732#[cfg(test)]
733mod test {
734    use rstest::{fixture, rstest};
735    use vortex_buffer::buffer;
736    use vortex_mask::Mask;
737
738    use crate::arrays::PrimitiveArray;
739    use crate::patches::Patches;
740    use crate::search_sorted::{SearchResult, SearchSortedSide};
741    use crate::validity::Validity;
742    use crate::{IntoArray, ToCanonical};
743
744    #[test]
745    fn test_filter() {
746        let patches = Patches::new(
747            100,
748            0,
749            buffer![10u32, 11, 20].into_array(),
750            buffer![100, 110, 200].into_array(),
751        );
752
753        let filtered = patches
754            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
755            .unwrap()
756            .unwrap();
757
758        let indices = filtered.indices().to_primitive();
759        let values = filtered.values().to_primitive();
760        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
761        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
762    }
763
764    #[fixture]
765    fn patches() -> Patches {
766        Patches::new(
767            20,
768            0,
769            buffer![2u64, 9, 15].into_array(),
770            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
771        )
772    }
773
774    #[rstest]
775    fn search_larger_than(patches: Patches) {
776        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
777        assert_eq!(res, SearchResult::NotFound(16));
778    }
779
780    #[rstest]
781    fn search_less_than(patches: Patches) {
782        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
783        assert_eq!(res, SearchResult::NotFound(2));
784    }
785
786    #[rstest]
787    fn search_found(patches: Patches) {
788        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
789        assert_eq!(res, SearchResult::Found(9));
790    }
791
792    #[rstest]
793    fn search_not_found_right(patches: Patches) {
794        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
795        assert_eq!(res, SearchResult::NotFound(16));
796    }
797
798    #[rstest]
799    fn search_sliced(patches: Patches) {
800        let sliced = patches.slice(7..20).unwrap();
801        assert_eq!(
802            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
803            SearchResult::NotFound(2)
804        );
805    }
806
807    #[test]
808    fn search_right() {
809        let patches = Patches::new(
810            6,
811            0,
812            buffer![0u8, 1, 4, 5].into_array(),
813            buffer![-128i8, -98, 8, 50].into_array(),
814        );
815
816        assert_eq!(
817            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
818            SearchResult::Found(2)
819        );
820        assert_eq!(
821            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
822            SearchResult::Found(6),
823        );
824        assert_eq!(
825            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
826            SearchResult::NotFound(2),
827        );
828        assert_eq!(
829            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
830            SearchResult::NotFound(6)
831        );
832    }
833
834    #[test]
835    fn search_left() {
836        let patches = Patches::new(
837            20,
838            0,
839            buffer![0u64, 1, 17, 18, 19].into_array(),
840            buffer![11i32, 22, 33, 44, 55].into_array(),
841        );
842        assert_eq!(
843            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
844            SearchResult::NotFound(2)
845        );
846        assert_eq!(
847            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
848            SearchResult::NotFound(19)
849        );
850    }
851
852    #[rstest]
853    fn take_with_nulls(patches: Patches) {
854        let taken = patches
855            .take(
856                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
857                    .into_array(),
858            )
859            .unwrap()
860            .unwrap();
861        let primitive_values = taken.values().to_primitive();
862        assert_eq!(taken.array_len(), 2);
863        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
864        assert_eq!(
865            primitive_values.validity_mask(),
866            Mask::from_iter(vec![true])
867        );
868    }
869
870    #[test]
871    fn test_slice() {
872        let values = buffer![15_u32, 135, 13531, 42].into_array();
873        let indices = buffer![10_u64, 11, 50, 100].into_array();
874
875        let patches = Patches::new(101, 0, indices, values);
876
877        let sliced = patches.slice(15..100).unwrap();
878        assert_eq!(sliced.array_len(), 100 - 15);
879        let primitive = sliced.values().to_primitive();
880
881        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
882    }
883
884    #[test]
885    fn doubly_sliced() {
886        let values = buffer![15_u32, 135, 13531, 42].into_array();
887        let indices = buffer![10_u64, 11, 50, 100].into_array();
888
889        let patches = Patches::new(101, 0, indices, values);
890
891        let sliced = patches.slice(15..100).unwrap();
892        assert_eq!(sliced.array_len(), 100 - 15);
893        let primitive = sliced.values().to_primitive();
894
895        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
896
897        let doubly_sliced = sliced.slice(35..36).unwrap();
898        let primitive_doubly_sliced = doubly_sliced.values().to_primitive();
899
900        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
901    }
902
903    #[test]
904    fn test_mask_all_true() {
905        let patches = Patches::new(
906            10,
907            0,
908            buffer![2u64, 5, 8].into_array(),
909            buffer![100i32, 200, 300].into_array(),
910        );
911
912        let mask = Mask::new_true(10);
913        let masked = patches.mask(&mask).unwrap();
914        assert!(masked.is_none());
915    }
916
917    #[test]
918    fn test_mask_all_false() {
919        let patches = Patches::new(
920            10,
921            0,
922            buffer![2u64, 5, 8].into_array(),
923            buffer![100i32, 200, 300].into_array(),
924        );
925
926        let mask = Mask::new_false(10);
927        let masked = patches.mask(&mask).unwrap().unwrap();
928
929        // No patch values should be masked
930        let masked_values = masked.values().to_primitive();
931        assert_eq!(masked_values.as_slice::<i32>(), &[100, 200, 300]);
932        assert!(masked_values.is_valid(0));
933        assert!(masked_values.is_valid(1));
934        assert!(masked_values.is_valid(2));
935
936        // Indices should remain unchanged
937        let indices = masked.indices().to_primitive();
938        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
939    }
940
941    #[test]
942    fn test_mask_partial() {
943        let patches = Patches::new(
944            10,
945            0,
946            buffer![2u64, 5, 8].into_array(),
947            buffer![100i32, 200, 300].into_array(),
948        );
949
950        // Mask that removes patches at indices 2 and 8 (but not 5)
951        let mask = Mask::from_iter([
952            false, false, true, false, false, false, false, false, true, false,
953        ]);
954        let masked = patches.mask(&mask).unwrap().unwrap();
955
956        // Only the patch at index 5 should remain
957        let masked_values = masked.values().to_primitive();
958        assert_eq!(masked_values.len(), 1);
959        assert_eq!(masked_values.as_slice::<i32>(), &[200]);
960
961        // Only index 5 should remain
962        let indices = masked.indices().to_primitive();
963        assert_eq!(indices.as_slice::<u64>(), &[5]);
964    }
965
966    #[test]
967    fn test_mask_with_offset() {
968        let patches = Patches::new(
969            10,
970            5,                                  // offset
971            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
972            buffer![100i32, 200, 300].into_array(),
973        );
974
975        // Mask that sets actual index 2 to null
976        let mask = Mask::from_iter([
977            false, false, true, false, false, false, false, false, false, false,
978        ]);
979
980        let masked = patches.mask(&mask).unwrap().unwrap();
981        assert_eq!(masked.array_len(), 10);
982        assert_eq!(masked.offset(), 5);
983        let indices = masked.indices().to_primitive();
984        assert_eq!(indices.as_slice::<u64>(), &[10, 13]);
985        let masked_values = masked.values().to_primitive();
986        assert_eq!(masked_values.as_slice::<i32>(), &[200, 300]);
987    }
988
989    #[test]
990    fn test_mask_nullable_values() {
991        let patches = Patches::new(
992            10,
993            0,
994            buffer![2u64, 5, 8].into_array(),
995            PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
996        );
997
998        // Test masking removes patch at index 2
999        let mask = Mask::from_iter([
1000            false, false, true, false, false, false, false, false, false, false,
1001        ]);
1002        let masked = patches.mask(&mask).unwrap().unwrap();
1003
1004        // Patches at indices 5 and 8 should remain
1005        let indices = masked.indices().to_primitive();
1006        assert_eq!(indices.as_slice::<u64>(), &[5, 8]);
1007
1008        // Values should be the null and 300
1009        let masked_values = masked.values().to_primitive();
1010        assert_eq!(masked_values.len(), 2);
1011        assert!(!masked_values.is_valid(0)); // the null value at index 5
1012        assert!(masked_values.is_valid(1)); // the 300 value at index 8
1013        assert_eq!(i32::try_from(&masked_values.scalar_at(1)).unwrap(), 300i32);
1014    }
1015
1016    #[test]
1017    fn test_filter_keep_all() {
1018        let patches = Patches::new(
1019            10,
1020            0,
1021            buffer![2u64, 5, 8].into_array(),
1022            buffer![100i32, 200, 300].into_array(),
1023        );
1024
1025        // Keep all indices (mask with indices 0-9)
1026        let mask = Mask::from_indices(10, (0..10).collect());
1027        let filtered = patches.filter(&mask).unwrap().unwrap();
1028
1029        let indices = filtered.indices().to_primitive();
1030        let values = filtered.values().to_primitive();
1031        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1032        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1033    }
1034
1035    #[test]
1036    fn test_filter_none() {
1037        let patches = Patches::new(
1038            10,
1039            0,
1040            buffer![2u64, 5, 8].into_array(),
1041            buffer![100i32, 200, 300].into_array(),
1042        );
1043
1044        // Filter out all (empty mask means keep nothing)
1045        let mask = Mask::from_indices(10, vec![]);
1046        let filtered = patches.filter(&mask).unwrap();
1047        assert!(filtered.is_none());
1048    }
1049
1050    #[test]
1051    fn test_filter_with_indices() {
1052        let patches = Patches::new(
1053            10,
1054            0,
1055            buffer![2u64, 5, 8].into_array(),
1056            buffer![100i32, 200, 300].into_array(),
1057        );
1058
1059        // Keep indices 2, 5, 9 (so patches at 2 and 5 remain)
1060        let mask = Mask::from_indices(10, vec![2, 5, 9]);
1061        let filtered = patches.filter(&mask).unwrap().unwrap();
1062
1063        let indices = filtered.indices().to_primitive();
1064        let values = filtered.values().to_primitive();
1065        assert_eq!(indices.as_slice::<u64>(), &[0, 1]); // Adjusted indices
1066        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
1067    }
1068
1069    #[test]
1070    fn test_slice_full_range() {
1071        let patches = Patches::new(
1072            10,
1073            0,
1074            buffer![2u64, 5, 8].into_array(),
1075            buffer![100i32, 200, 300].into_array(),
1076        );
1077
1078        let sliced = patches.slice(0..10).unwrap();
1079
1080        let indices = sliced.indices().to_primitive();
1081        let values = sliced.values().to_primitive();
1082        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1083        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1084    }
1085
1086    #[test]
1087    fn test_slice_partial() {
1088        let patches = Patches::new(
1089            10,
1090            0,
1091            buffer![2u64, 5, 8].into_array(),
1092            buffer![100i32, 200, 300].into_array(),
1093        );
1094
1095        // Slice from 3 to 8 (includes patch at 5)
1096        let sliced = patches.slice(3..8).unwrap();
1097
1098        let indices = sliced.indices().to_primitive();
1099        let values = sliced.values().to_primitive();
1100        assert_eq!(indices.as_slice::<u64>(), &[5]); // Index stays the same
1101        assert_eq!(values.as_slice::<i32>(), &[200]);
1102        assert_eq!(sliced.array_len(), 5); // 8 - 3 = 5
1103        assert_eq!(sliced.offset(), 3); // New offset
1104    }
1105
1106    #[test]
1107    fn test_slice_no_patches() {
1108        let patches = Patches::new(
1109            10,
1110            0,
1111            buffer![2u64, 5, 8].into_array(),
1112            buffer![100i32, 200, 300].into_array(),
1113        );
1114
1115        // Slice from 6 to 7 (no patches in this range)
1116        let sliced = patches.slice(6..7);
1117        assert!(sliced.is_none());
1118    }
1119
1120    #[test]
1121    fn test_slice_with_offset() {
1122        let patches = Patches::new(
1123            10,
1124            5,                                  // offset
1125            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
1126            buffer![100i32, 200, 300].into_array(),
1127        );
1128
1129        // Slice from 3 to 8 (includes patch at actual index 5)
1130        let sliced = patches.slice(3..8).unwrap();
1131
1132        let indices = sliced.indices().to_primitive();
1133        let values = sliced.values().to_primitive();
1134        assert_eq!(indices.as_slice::<u64>(), &[10]); // Index stays the same (offset + 5 = 10)
1135        assert_eq!(values.as_slice::<i32>(), &[200]);
1136        assert_eq!(sliced.offset(), 8); // New offset = 5 + 3
1137    }
1138
1139    #[test]
1140    fn test_patch_values() {
1141        let patches = Patches::new(
1142            10,
1143            0,
1144            buffer![2u64, 5, 8].into_array(),
1145            buffer![100i32, 200, 300].into_array(),
1146        );
1147
1148        let values = patches.values().to_primitive();
1149        assert_eq!(i32::try_from(&values.scalar_at(0)).unwrap(), 100i32);
1150        assert_eq!(i32::try_from(&values.scalar_at(1)).unwrap(), 200i32);
1151        assert_eq!(i32::try_from(&values.scalar_at(2)).unwrap(), 300i32);
1152    }
1153
1154    #[test]
1155    fn test_indices_range() {
1156        let patches = Patches::new(
1157            10,
1158            0,
1159            buffer![2u64, 5, 8].into_array(),
1160            buffer![100i32, 200, 300].into_array(),
1161        );
1162
1163        assert_eq!(patches.min_index(), 2);
1164        assert_eq!(patches.max_index(), 8);
1165    }
1166
1167    #[test]
1168    fn test_search_index() {
1169        let patches = Patches::new(
1170            10,
1171            0,
1172            buffer![2u64, 5, 8].into_array(),
1173            buffer![100i32, 200, 300].into_array(),
1174        );
1175
1176        // Search for exact indices
1177        assert_eq!(patches.search_index(2), SearchResult::Found(0));
1178        assert_eq!(patches.search_index(5), SearchResult::Found(1));
1179        assert_eq!(patches.search_index(8), SearchResult::Found(2));
1180
1181        // Search for non-patch indices
1182        assert_eq!(patches.search_index(0), SearchResult::NotFound(0));
1183        assert_eq!(patches.search_index(3), SearchResult::NotFound(1));
1184        assert_eq!(patches.search_index(6), SearchResult::NotFound(2));
1185        assert_eq!(patches.search_index(9), SearchResult::NotFound(3));
1186    }
1187
1188    #[test]
1189    fn test_mask_boundary_patches() {
1190        // Test masking patches at array boundaries
1191        let patches = Patches::new(
1192            10,
1193            0,
1194            buffer![0u64, 9].into_array(),
1195            buffer![100i32, 200].into_array(),
1196        );
1197
1198        let mask = Mask::from_iter([
1199            true, false, false, false, false, false, false, false, false, false,
1200        ]);
1201        let masked = patches.mask(&mask).unwrap();
1202        assert!(masked.is_some());
1203        let masked = masked.unwrap();
1204        let indices = masked.indices().to_primitive();
1205        assert_eq!(indices.as_slice::<u64>(), &[9]);
1206        let values = masked.values().to_primitive();
1207        assert_eq!(values.as_slice::<i32>(), &[200]);
1208    }
1209
1210    #[test]
1211    fn test_mask_all_patches_removed() {
1212        // Test when all patches are masked out
1213        let patches = Patches::new(
1214            10,
1215            0,
1216            buffer![2u64, 5, 8].into_array(),
1217            buffer![100i32, 200, 300].into_array(),
1218        );
1219
1220        // Mask that removes all patches
1221        let mask = Mask::from_iter([
1222            false, false, true, false, false, true, false, false, true, false,
1223        ]);
1224        let masked = patches.mask(&mask).unwrap();
1225        assert!(masked.is_none());
1226    }
1227
1228    #[test]
1229    fn test_mask_no_patches_removed() {
1230        // Test when no patches are masked
1231        let patches = Patches::new(
1232            10,
1233            0,
1234            buffer![2u64, 5, 8].into_array(),
1235            buffer![100i32, 200, 300].into_array(),
1236        );
1237
1238        // Mask that doesn't affect any patches
1239        let mask = Mask::from_iter([
1240            true, false, false, true, false, false, true, false, false, true,
1241        ]);
1242        let masked = patches.mask(&mask).unwrap().unwrap();
1243
1244        let indices = masked.indices().to_primitive();
1245        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1246        let values = masked.values().to_primitive();
1247        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1248    }
1249
1250    #[test]
1251    fn test_mask_single_patch() {
1252        // Test with a single patch
1253        let patches = Patches::new(
1254            5,
1255            0,
1256            buffer![2u64].into_array(),
1257            buffer![42i32].into_array(),
1258        );
1259
1260        // Mask that removes the single patch
1261        let mask = Mask::from_iter([false, false, true, false, false]);
1262        let masked = patches.mask(&mask).unwrap();
1263        assert!(masked.is_none());
1264
1265        // Mask that keeps the single patch
1266        let mask = Mask::from_iter([true, false, false, true, false]);
1267        let masked = patches.mask(&mask).unwrap().unwrap();
1268        let indices = masked.indices().to_primitive();
1269        assert_eq!(indices.as_slice::<u64>(), &[2]);
1270    }
1271
1272    #[test]
1273    fn test_mask_contiguous_patches() {
1274        // Test with contiguous patches
1275        let patches = Patches::new(
1276            10,
1277            0,
1278            buffer![3u64, 4, 5, 6].into_array(),
1279            buffer![100i32, 200, 300, 400].into_array(),
1280        );
1281
1282        // Mask that removes middle patches
1283        let mask = Mask::from_iter([
1284            false, false, false, false, true, true, false, false, false, false,
1285        ]);
1286        let masked = patches.mask(&mask).unwrap().unwrap();
1287
1288        let indices = masked.indices().to_primitive();
1289        assert_eq!(indices.as_slice::<u64>(), &[3, 6]);
1290        let values = masked.values().to_primitive();
1291        assert_eq!(values.as_slice::<i32>(), &[100, 400]);
1292    }
1293
1294    #[test]
1295    fn test_mask_with_large_offset() {
1296        // Test with a large offset that shifts all indices
1297        let patches = Patches::new(
1298            20,
1299            15,
1300            buffer![16u64, 17, 19].into_array(), // actual indices are 1, 2, 4
1301            buffer![100i32, 200, 300].into_array(),
1302        );
1303
1304        // Mask that removes the patch at actual index 2
1305        let mask = Mask::from_iter([
1306            false, false, true, false, false, false, false, false, false, false, false, false,
1307            false, false, false, false, false, false, false, false,
1308        ]);
1309        let masked = patches.mask(&mask).unwrap().unwrap();
1310
1311        let indices = masked.indices().to_primitive();
1312        assert_eq!(indices.as_slice::<u64>(), &[16, 19]);
1313        let values = masked.values().to_primitive();
1314        assert_eq!(values.as_slice::<i32>(), &[100, 300]);
1315    }
1316
1317    #[test]
1318    #[should_panic(expected = "Filter mask length 5 does not match array length 10")]
1319    fn test_mask_wrong_length() {
1320        let patches = Patches::new(
1321            10,
1322            0,
1323            buffer![2u64, 5, 8].into_array(),
1324            buffer![100i32, 200, 300].into_array(),
1325        );
1326
1327        // Mask with wrong length
1328        let mask = Mask::from_iter([false, false, true, false, false]);
1329        let _ = patches.mask(&mask).unwrap();
1330    }
1331}