vortex_array/
patches.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Debug;
6use std::hash::Hash;
7use std::ops::Range;
8
9use arrow_buffer::BooleanBuffer;
10use itertools::Itertools as _;
11use num_traits::{NumCast, ToPrimitive};
12use serde::{Deserialize, Serialize};
13use vortex_buffer::BufferMut;
14use vortex_dtype::Nullability::NonNullable;
15use vortex_dtype::{
16    DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
17};
18use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
19use vortex_mask::{AllOr, Mask};
20use vortex_scalar::{PValue, Scalar};
21use vortex_utils::aliases::hash_map::HashMap;
22
23use crate::arrays::PrimitiveArray;
24use crate::compute::{cast, filter, take};
25use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
26use crate::vtable::ValidityHelper;
27use crate::{Array, ArrayRef, IntoArray, ToCanonical};
28
/// Serializable summary of a set of patches: the number of patches, the offset associated with
/// the stored indices, and the primitive type used to encode those indices.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches.
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset associated with the patch indices.
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// The `PType` of the patch indices, stored as its `i32` enumeration value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
38
39impl PatchesMetadata {
40    pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
41        Self {
42            len: len as u64,
43            offset: offset as u64,
44            indices_ptype: indices_ptype as i32,
45        }
46    }
47
48    #[inline]
49    pub fn len(&self) -> usize {
50        usize::try_from(self.len).vortex_expect("len is a valid usize")
51    }
52
53    #[inline]
54    pub fn is_empty(&self) -> bool {
55        self.len == 0
56    }
57
58    #[inline]
59    pub fn offset(&self) -> usize {
60        usize::try_from(self.offset).vortex_expect("offset is a valid usize")
61    }
62
63    #[inline]
64    pub fn indices_dtype(&self) -> DType {
65        assert!(
66            self.indices_ptype().is_unsigned_int(),
67            "Patch indices must be unsigned integers"
68        );
69        DType::Primitive(self.indices_ptype(), NonNullable)
70    }
71}
72
/// A helper for working with patched arrays.
///
/// A patch is an `(index, value)` pair; `indices` and `values` are parallel arrays of equal
/// length. Stored indices are shifted by `offset`: the position of a patch within the patched
/// array is `stored_index - offset`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array being patched.
    array_len: usize,
    /// Amount subtracted from each stored index to obtain its position in the patched array.
    offset: usize,
    /// Non-nullable unsigned integer positions of the patches (before subtracting `offset`).
    indices: ArrayRef,
    /// Patch values, one per entry in `indices`.
    values: ArrayRef,
}
81
82impl Patches {
83    pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
84        assert_eq!(
85            indices.len(),
86            values.len(),
87            "Patch indices and values must have the same length"
88        );
89        assert!(
90            indices.dtype().is_unsigned_int() && !indices.dtype().is_nullable(),
91            "Patch indices must be non-nullable unsigned integers, got {:?}",
92            indices.dtype()
93        );
94        assert!(
95            indices.len() <= array_len,
96            "Patch indices must be shorter than the array length"
97        );
98        assert!(!indices.is_empty(), "Patch indices must not be empty");
99        let max = usize::try_from(&indices.scalar_at(indices.len() - 1))
100            .vortex_expect("indices must be a number");
101        assert!(
102            max - offset < array_len,
103            "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
104        );
105
106        Self {
107            array_len,
108            offset,
109            indices,
110            values,
111        }
112    }
113
    /// Construct new patches without validating any of the arguments.
    ///
    /// # Safety
    ///
    /// Callers must uphold the invariants that [`Patches::new`] asserts, plus sortedness:
    /// * Indices and values have the same length
    /// * Indices is a non-nullable unsigned integer type
    /// * Indices is not empty and has no more entries than `array_len`
    /// * Indices must be sorted
    /// * Last value in indices, minus `offset`, is smaller than `array_len`
    pub unsafe fn new_unchecked(
        array_len: usize,
        offset: usize,
        indices: ArrayRef,
        values: ArrayRef,
    ) -> Self {
        Self {
            array_len,
            offset,
            indices,
            values,
        }
    }
136
    /// Length of the array being patched.
    pub fn array_len(&self) -> usize {
        self.array_len
    }

    /// Number of patches, i.e. the length of the indices array.
    pub fn num_patches(&self) -> usize {
        self.indices.len()
    }

    /// The [`DType`] of the patch values.
    pub fn dtype(&self) -> &DType {
        self.values.dtype()
    }

    /// Borrows the stored patch indices (not adjusted by the offset).
    pub fn indices(&self) -> &ArrayRef {
        &self.indices
    }

    /// Consumes the patches, returning the stored patch indices.
    pub fn into_indices(self) -> ArrayRef {
        self.indices
    }

    /// Mutable access to the stored patch indices.
    ///
    /// No validation is performed on what the caller writes through this reference.
    pub fn indices_mut(&mut self) -> &mut ArrayRef {
        &mut self.indices
    }

    /// Borrows the patch values.
    pub fn values(&self) -> &ArrayRef {
        &self.values
    }

    /// Consumes the patches, returning the patch values.
    pub fn into_values(self) -> ArrayRef {
        self.values
    }

    /// Mutable access to the patch values.
    ///
    /// No validation is performed on what the caller writes through this reference.
    pub fn values_mut(&mut self) -> &mut ArrayRef {
        &mut self.values
    }

    /// The offset subtracted from stored indices to obtain positions in the patched array.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// The primitive type of the patch indices.
    ///
    /// # Panics
    ///
    /// Panics if the indices dtype is not primitive.
    pub fn indices_ptype(&self) -> PType {
        PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
    }
180
181    pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
182        if self.indices.len() > len {
183            vortex_bail!(
184                "Patch indices {} are longer than the array length {}",
185                self.indices.len(),
186                len
187            );
188        }
189        if self.values.dtype() != dtype {
190            vortex_bail!(
191                "Patch values dtype {} does not match array dtype {}",
192                self.values.dtype(),
193                dtype
194            );
195        }
196        Ok(PatchesMetadata {
197            len: self.indices.len() as u64,
198            offset: self.offset as u64,
199            indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
200                as i32,
201        })
202    }
203
204    pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
205        // SAFETY: casting does not affect the relationship between the indices and values
206        unsafe {
207            Ok(Self::new_unchecked(
208                self.array_len,
209                self.offset,
210                self.indices,
211                cast(&self.values, values_dtype)?,
212            ))
213        }
214    }
215
216    /// Get the patched value at a given index if it exists.
217    pub fn get_patched(&self, index: usize) -> Option<Scalar> {
218        self.search_index(index)
219            .to_found()
220            .map(|patch_idx| self.values().scalar_at(patch_idx))
221    }
222
223    /// Return the insertion point of `index` in the [Self::indices].
224    pub fn search_index(&self, index: usize) -> SearchResult {
225        if self.indices.is_canonical() {
226            let primitive = self
227                .indices
228                .to_primitive()
229                .vortex_expect("indices are canonical and must be primitive");
230            match_each_integer_ptype!(primitive.ptype(), |T| {
231                let Ok(needle) = T::try_from(index + self.offset) else {
232                    // If the needle is not of type T, then it cannot possibly be in this array.
233                    //
234                    // The needle is a non-negative integer (a usize); therefore, it must be larger
235                    // than all values in this array.
236                    return SearchResult::NotFound(primitive.len());
237                };
238                return primitive
239                    .as_slice::<T>()
240                    .search_sorted(&needle, SearchSortedSide::Left);
241            });
242        }
243        self.indices.as_primitive_typed().search_sorted(
244            &PValue::U64((index + self.offset) as u64),
245            SearchSortedSide::Left,
246        )
247    }
248
249    /// Return the search_sorted result for the given target re-mapped into the original indices.
250    pub fn search_sorted<T: Into<Scalar>>(
251        &self,
252        target: T,
253        side: SearchSortedSide,
254    ) -> VortexResult<SearchResult> {
255        let target = target.into();
256
257        let sr = if self.values().dtype().is_primitive() {
258            self.values()
259                .as_primitive_typed()
260                .search_sorted(&target.as_primitive().pvalue(), side)
261        } else {
262            self.values().search_sorted(&target, side)
263        };
264
265        let index_idx = sr.to_offsets_index(self.indices().len(), side);
266        let index = usize::try_from(&self.indices().scalar_at(index_idx))? - self.offset;
267        Ok(match sr {
268            // If we reached the end of patched values when searching then the result is one after the last patch index
269            SearchResult::Found(i) => SearchResult::Found(
270                if i == self.indices().len() || side == SearchSortedSide::Right {
271                    index + 1
272                } else {
273                    index
274                },
275            ),
276            // If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
277            SearchResult::NotFound(i) => {
278                SearchResult::NotFound(if i == 0 { index } else { index + 1 })
279            }
280        })
281    }
282
283    /// Returns the minimum patch index
284    pub fn min_index(&self) -> usize {
285        let first = self
286            .indices
287            .scalar_at(0)
288            .as_primitive()
289            .as_::<usize>()
290            .vortex_expect("non-null");
291        first - self.offset
292    }
293
294    /// Returns the maximum patch index
295    pub fn max_index(&self) -> usize {
296        let last = self
297            .indices
298            .scalar_at(self.indices.len() - 1)
299            .as_primitive()
300            .as_::<usize>()
301            .vortex_expect("non-null");
302        last - self.offset
303    }
304
305    /// Filter the patches by a mask, resulting in new patches for the filtered array.
306    pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
307        if mask.len() != self.array_len {
308            vortex_bail!(
309                "Filter mask length {} does not match array length {}",
310                mask.len(),
311                self.array_len
312            );
313        }
314
315        match mask.indices() {
316            AllOr::All => Ok(Some(self.clone())),
317            AllOr::None => Ok(None),
318            AllOr::Some(mask_indices) => {
319                let flat_indices = self.indices().to_primitive()?;
320                match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
321                    filter_patches_with_mask(
322                        flat_indices.as_slice::<I>(),
323                        self.offset(),
324                        self.values(),
325                        mask_indices,
326                    )
327                })
328            }
329        }
330    }
331
332    /// Mask the patches, REMOVING the patches where the mask is true.
333    /// Unlike filter, this preserves the patch indices.
334    /// Unlike mask on a single array, this does not set masked values to null.
335    pub fn mask(&self, mask: &Mask) -> VortexResult<Option<Self>> {
336        if mask.len() != self.array_len {
337            vortex_bail!(
338                "Filter mask length {} does not match array length {}",
339                mask.len(),
340                self.array_len
341            );
342        }
343
344        let filter_mask = match mask.boolean_buffer() {
345            AllOr::All => return Ok(None),
346            AllOr::None => return Ok(Some(self.clone())),
347            AllOr::Some(masked) => {
348                let patch_indices = self.indices().to_primitive()?;
349                match_each_unsigned_integer_ptype!(patch_indices.ptype(), |P| {
350                    let patch_indices = patch_indices.as_slice::<P>();
351                    Mask::from_buffer(BooleanBuffer::collect_bool(patch_indices.len(), |i| {
352                        #[allow(clippy::cast_possible_truncation)]
353                        let idx = (patch_indices[i] as usize) - self.offset;
354                        !masked.value(idx)
355                    }))
356                })
357            }
358        };
359
360        if filter_mask.all_false() {
361            return Ok(None);
362        }
363
364        // SAFETY: filtering indices/values with same mask maintains their 1:1 relationship
365        unsafe {
366            Ok(Some(Self::new_unchecked(
367                self.array_len,
368                self.offset,
369                filter(&self.indices, &filter_mask)?,
370                filter(&self.values, &filter_mask)?,
371            )))
372        }
373    }
374
375    /// Slice the patches by a range of the patched array.
376    pub fn slice(&self, range: Range<usize>) -> Option<Self> {
377        if range.len() == 1 {
378            let patch_index = self.search_index(range.start).to_found()?;
379            let values = self.values.slice(patch_index..patch_index + 1);
380            let indices = self.indices.slice(patch_index..patch_index + 1);
381            return Some(Self::new(1, range.start + self.offset(), indices, values));
382        }
383
384        let patch_start = self.search_index(range.start).to_index();
385        let patch_stop = self.search_index(range.end).to_index();
386
387        if patch_start == patch_stop {
388            return None;
389        }
390
391        // Slice out the values and indices
392        let values = self.values().slice(patch_start..patch_stop);
393        let indices = self.indices().slice(patch_start..patch_stop);
394
395        Some(Self::new(
396            range.len(),
397            range.start + self.offset(),
398            indices,
399            values,
400        ))
401    }
402
403    // https://docs.google.com/spreadsheets/d/1D9vBZ1QJ6mwcIvV5wIL0hjGgVchcEnAyhvitqWu2ugU
404    const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
405
406    fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
407        (self.num_patches() as f64 / take_indices.len() as f64)
408            < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
409    }
410
411    /// Take the indices from the patches
412    ///
413    /// Any nulls in take_indices are added to the resulting patches.
414    pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
415        if take_indices.is_empty() {
416            return Ok(None);
417        }
418
419        let take_indices = take_indices.to_primitive()?;
420        if self.is_map_faster_than_search(&take_indices) {
421            self.take_map(take_indices, true)
422        } else {
423            self.take_search(take_indices, true)
424        }
425    }
426
427    /// Take the indices from the patches.
428    ///
429    /// Any nulls in take_indices are ignored.
430    pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
431        if take_indices.is_empty() {
432            return Ok(None);
433        }
434
435        let take_indices = take_indices.to_primitive()?;
436        if self.is_map_faster_than_search(&take_indices) {
437            self.take_map(take_indices, false)
438        } else {
439            self.take_search(take_indices, false)
440        }
441    }
442
443    pub fn take_search(
444        &self,
445        take_indices: PrimitiveArray,
446        include_nulls: bool,
447    ) -> VortexResult<Option<Self>> {
448        let indices = self.indices.to_primitive()?;
449        let new_length = take_indices.len();
450
451        let Some((new_indices, values_indices)) =
452            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
453                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
454                    take_search::<_, TakeIndices>(
455                        indices.as_slice::<Indices>(),
456                        take_indices,
457                        self.offset(),
458                        include_nulls,
459                    )?
460                })
461            })
462        else {
463            return Ok(None);
464        };
465
466        Ok(Some(Self::new(
467            new_length,
468            0,
469            new_indices,
470            take(self.values(), &values_indices)?,
471        )))
472    }
473
474    pub fn take_map(
475        &self,
476        take_indices: PrimitiveArray,
477        include_nulls: bool,
478    ) -> VortexResult<Option<Self>> {
479        let indices = self.indices.to_primitive()?;
480        let new_length = take_indices.len();
481
482        let Some((new_sparse_indices, value_indices)) =
483            match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
484                match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
485                    take_map::<_, TakeIndices>(
486                        indices.as_slice::<Indices>(),
487                        take_indices,
488                        self.offset(),
489                        self.min_index(),
490                        self.max_index(),
491                        include_nulls,
492                    )?
493                })
494            })
495        else {
496            return Ok(None);
497        };
498
499        Ok(Some(Patches::new(
500            new_length,
501            0,
502            new_sparse_indices,
503            take(self.values(), &value_indices)?,
504        )))
505    }
506
507    pub fn map_values<F>(self, f: F) -> VortexResult<Self>
508    where
509        F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
510    {
511        let values = f(self.values)?;
512        if self.indices.len() != values.len() {
513            vortex_bail!(
514                "map_values must preserve length: expected {} received {}",
515                self.indices.len(),
516                values.len()
517            )
518        }
519        Ok(Self::new(self.array_len, self.offset, self.indices, values))
520    }
521}
522
523fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
524    indices: &[I],
525    take_indices: PrimitiveArray,
526    indices_offset: usize,
527    include_nulls: bool,
528) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
529where
530    usize: TryFrom<T>,
531    VortexError: From<<usize as TryFrom<T>>::Error>,
532{
533    let take_indices_validity = take_indices.validity();
534    let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
535
536    let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
537        .as_slice::<T>()
538        .iter()
539        .enumerate()
540        .filter_map(|(i, &v)| {
541            I::from(v)
542                .and_then(|v| {
543                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
544                    if include_nulls && take_indices_validity.is_null(i) {
545                        Some(0)
546                    } else {
547                        indices
548                            .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
549                            .to_found()
550                            .map(|patch_idx| patch_idx as u64)
551                    }
552                })
553                .map(|patch_idx| (patch_idx, i as u64))
554        })
555        .unzip();
556
557    if new_indices.is_empty() {
558        return Ok(None);
559    }
560
561    let new_indices = new_indices.into_array();
562    let values_validity = take_indices_validity.take(&new_indices)?;
563    Ok(Some((
564        new_indices,
565        PrimitiveArray::new(values_indices, values_validity).into_array(),
566    )))
567}
568
569fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
570    indices: &[I],
571    take_indices: PrimitiveArray,
572    indices_offset: usize,
573    min_index: usize,
574    max_index: usize,
575    include_nulls: bool,
576) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
577where
578    usize: TryFrom<T>,
579    VortexError: From<<I as TryFrom<usize>>::Error>,
580{
581    let take_indices_validity = take_indices.validity();
582    let take_indices = take_indices.as_slice::<T>();
583    let offset_i = I::try_from(indices_offset)?;
584
585    let sparse_index_to_value_index: HashMap<I, usize> = indices
586        .iter()
587        .copied()
588        .map(|idx| idx - offset_i)
589        .enumerate()
590        .map(|(value_index, sparse_index)| (sparse_index, value_index))
591        .collect();
592
593    let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
594        .iter()
595        .copied()
596        .map(usize::try_from)
597        .process_results(|iter| {
598            iter.enumerate()
599                .filter_map(|(idx_in_take, ti)| {
600                    // If we have to take nulls the take index doesn't matter, make it 0 for consistency
601                    if include_nulls && take_indices_validity.is_null(idx_in_take) {
602                        Some((idx_in_take as u64, 0))
603                    } else if ti < min_index || ti > max_index {
604                        None
605                    } else {
606                        sparse_index_to_value_index
607                            .get(
608                                &I::try_from(ti)
609                                    .vortex_expect("take index is between min and max index"),
610                            )
611                            .map(|value_index| (idx_in_take as u64, *value_index as u64))
612                    }
613                })
614                .unzip()
615        })
616        .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
617
618    if new_sparse_indices.is_empty() {
619        return Ok(None);
620    }
621
622    let new_sparse_indices = new_sparse_indices.into_array();
623    let values_validity = take_indices_validity.take(&new_sparse_indices)?;
624    Ok(Some((
625        new_sparse_indices,
626        PrimitiveArray::new(value_indices, values_validity).into_array(),
627    )))
628}
629
630/// Filter patches with the provided mask (in flattened space).
631///
632/// The filter mask may contain indices that are non-patched. The return value of this function
633/// is a new set of `Patches` with the indices relative to the provided `mask` rank, and the
634/// patch values.
635fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
636    patch_indices: &[T],
637    offset: usize,
638    patch_values: &dyn Array,
639    mask_indices: &[usize],
640) -> VortexResult<Option<Patches>> {
641    let true_count = mask_indices.len();
642    let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
643    let mut new_mask_indices = Vec::with_capacity(true_count);
644
645    // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that
646    // the patches are relatively sparse compared to the overall mask, and so many indices in the
647    // mask will end up being skipped.
648    const STRIDE: usize = 4;
649
650    let mut mask_idx = 0usize;
651    let mut true_idx = 0usize;
652
653    while mask_idx < patch_indices.len() && true_idx < true_count {
654        // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices`
655        //  and `mask_indices`. We assume that Patches are sparse relative to the global space of
656        //  the mask (which covers both patch and non-patch values of the parent array), and so to
657        //  quickly jump through regions with no overlap, we attempt to move our pointers by STRIDE
658        //  elements on each iteration. If we cannot rule out overlap due to min/max values, we
659        //  fallback to performing a two-way iterator merge.
660        if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
661            // Load a vector of each into our registers.
662            let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
663            let left_max = patch_indices[mask_idx + STRIDE]
664                .to_usize()
665                .vortex_expect("left_max")
666                - offset;
667            let right_min = mask_indices[true_idx];
668            let right_max = mask_indices[true_idx + STRIDE];
669
670            if left_min > right_max {
671                // Advance right side
672                true_idx += STRIDE;
673                continue;
674            } else if right_min > left_max {
675                mask_idx += STRIDE;
676                continue;
677            } else {
678                // Fallthrough to direct comparison path.
679            }
680        }
681
682        // Two-way sorted iterator merge:
683
684        let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
685        let right = mask_indices[true_idx];
686
687        match left.cmp(&right) {
688            Ordering::Less => {
689                mask_idx += 1;
690            }
691            Ordering::Greater => {
692                true_idx += 1;
693            }
694            Ordering::Equal => {
695                // Save the mask index as well as the positional index.
696                new_mask_indices.push(mask_idx);
697                new_patch_indices.push(true_idx as u64);
698
699                mask_idx += 1;
700                true_idx += 1;
701            }
702        }
703    }
704
705    if new_mask_indices.is_empty() {
706        return Ok(None);
707    }
708
709    let new_patch_indices = new_patch_indices.into_array();
710    let new_patch_values = filter(
711        patch_values,
712        &Mask::from_indices(patch_values.len(), new_mask_indices),
713    )?;
714
715    Ok(Some(Patches::new(
716        true_count,
717        0,
718        new_patch_indices,
719        new_patch_values,
720    )))
721}
722
723#[cfg(test)]
724mod test {
725    use rstest::{fixture, rstest};
726    use vortex_buffer::buffer;
727    use vortex_mask::Mask;
728
729    use crate::arrays::PrimitiveArray;
730    use crate::patches::Patches;
731    use crate::search_sorted::{SearchResult, SearchSortedSide};
732    use crate::validity::Validity;
733    use crate::{IntoArray, ToCanonical};
734
735    #[test]
736    fn test_filter() {
737        let patches = Patches::new(
738            100,
739            0,
740            buffer![10u32, 11, 20].into_array(),
741            buffer![100, 110, 200].into_array(),
742        );
743
744        let filtered = patches
745            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
746            .unwrap()
747            .unwrap();
748
749        let indices = filtered.indices().to_primitive().unwrap();
750        let values = filtered.values().to_primitive().unwrap();
751        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
752        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
753    }
754
755    #[fixture]
756    fn patches() -> Patches {
757        Patches::new(
758            20,
759            0,
760            buffer![2u64, 9, 15].into_array(),
761            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
762        )
763    }
764
765    #[rstest]
766    fn search_larger_than(patches: Patches) {
767        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
768        assert_eq!(res, SearchResult::NotFound(16));
769    }
770
771    #[rstest]
772    fn search_less_than(patches: Patches) {
773        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
774        assert_eq!(res, SearchResult::NotFound(2));
775    }
776
777    #[rstest]
778    fn search_found(patches: Patches) {
779        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
780        assert_eq!(res, SearchResult::Found(9));
781    }
782
783    #[rstest]
784    fn search_not_found_right(patches: Patches) {
785        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
786        assert_eq!(res, SearchResult::NotFound(16));
787    }
788
789    #[rstest]
790    fn search_sliced(patches: Patches) {
791        let sliced = patches.slice(7..20).unwrap();
792        assert_eq!(
793            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
794            SearchResult::NotFound(2)
795        );
796    }
797
798    #[test]
799    fn search_right() {
800        let patches = Patches::new(
801            6,
802            0,
803            buffer![0u8, 1, 4, 5].into_array(),
804            buffer![-128i8, -98, 8, 50].into_array(),
805        );
806
807        assert_eq!(
808            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
809            SearchResult::Found(2)
810        );
811        assert_eq!(
812            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
813            SearchResult::Found(6),
814        );
815        assert_eq!(
816            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
817            SearchResult::NotFound(2),
818        );
819        assert_eq!(
820            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
821            SearchResult::NotFound(6)
822        );
823    }
824
825    #[test]
826    fn search_left() {
827        let patches = Patches::new(
828            20,
829            0,
830            buffer![0u64, 1, 17, 18, 19].into_array(),
831            buffer![11i32, 22, 33, 44, 55].into_array(),
832        );
833        assert_eq!(
834            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
835            SearchResult::NotFound(2)
836        );
837        assert_eq!(
838            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
839            SearchResult::NotFound(19)
840        );
841    }
842
843    #[rstest]
844    fn take_with_nulls(patches: Patches) {
845        let taken = patches
846            .take(
847                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
848                    .into_array(),
849            )
850            .unwrap()
851            .unwrap();
852        let primitive_values = taken.values().to_primitive().unwrap();
853        assert_eq!(taken.array_len(), 2);
854        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
855        assert_eq!(
856            primitive_values.validity_mask(),
857            Mask::from_iter(vec![true])
858        );
859    }
860
861    #[test]
862    fn test_slice() {
863        let values = buffer![15_u32, 135, 13531, 42].into_array();
864        let indices = buffer![10_u64, 11, 50, 100].into_array();
865
866        let patches = Patches::new(101, 0, indices, values);
867
868        let sliced = patches.slice(15..100).unwrap();
869        assert_eq!(sliced.array_len(), 100 - 15);
870        let primitive = sliced.values().to_primitive().unwrap();
871
872        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
873    }
874
875    #[test]
876    fn doubly_sliced() {
877        let values = buffer![15_u32, 135, 13531, 42].into_array();
878        let indices = buffer![10_u64, 11, 50, 100].into_array();
879
880        let patches = Patches::new(101, 0, indices, values);
881
882        let sliced = patches.slice(15..100).unwrap();
883        assert_eq!(sliced.array_len(), 100 - 15);
884        let primitive = sliced.values().to_primitive().unwrap();
885
886        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
887
888        let doubly_sliced = sliced.slice(35..36).unwrap();
889        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();
890
891        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
892    }
893
894    #[test]
895    fn test_mask_all_true() {
896        let patches = Patches::new(
897            10,
898            0,
899            buffer![2u64, 5, 8].into_array(),
900            buffer![100i32, 200, 300].into_array(),
901        );
902
903        let mask = Mask::new_true(10);
904        let masked = patches.mask(&mask).unwrap();
905        assert!(masked.is_none());
906    }
907
908    #[test]
909    fn test_mask_all_false() {
910        let patches = Patches::new(
911            10,
912            0,
913            buffer![2u64, 5, 8].into_array(),
914            buffer![100i32, 200, 300].into_array(),
915        );
916
917        let mask = Mask::new_false(10);
918        let masked = patches.mask(&mask).unwrap().unwrap();
919
920        // No patch values should be masked
921        let masked_values = masked.values().to_primitive().unwrap();
922        assert_eq!(masked_values.as_slice::<i32>(), &[100, 200, 300]);
923        assert!(masked_values.is_valid(0));
924        assert!(masked_values.is_valid(1));
925        assert!(masked_values.is_valid(2));
926
927        // Indices should remain unchanged
928        let indices = masked.indices().to_primitive().unwrap();
929        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
930    }
931
932    #[test]
933    fn test_mask_partial() {
934        let patches = Patches::new(
935            10,
936            0,
937            buffer![2u64, 5, 8].into_array(),
938            buffer![100i32, 200, 300].into_array(),
939        );
940
941        // Mask that removes patches at indices 2 and 8 (but not 5)
942        let mask = Mask::from_iter([
943            false, false, true, false, false, false, false, false, true, false,
944        ]);
945        let masked = patches.mask(&mask).unwrap().unwrap();
946
947        // Only the patch at index 5 should remain
948        let masked_values = masked.values().to_primitive().unwrap();
949        assert_eq!(masked_values.len(), 1);
950        assert_eq!(masked_values.as_slice::<i32>(), &[200]);
951
952        // Only index 5 should remain
953        let indices = masked.indices().to_primitive().unwrap();
954        assert_eq!(indices.as_slice::<u64>(), &[5]);
955    }
956
957    #[test]
958    fn test_mask_with_offset() {
959        let patches = Patches::new(
960            10,
961            5,                                  // offset
962            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
963            buffer![100i32, 200, 300].into_array(),
964        );
965
966        // Mask that sets actual index 2 to null
967        let mask = Mask::from_iter([
968            false, false, true, false, false, false, false, false, false, false,
969        ]);
970
971        let masked = patches.mask(&mask).unwrap().unwrap();
972        assert_eq!(masked.array_len(), 10);
973        assert_eq!(masked.offset(), 5);
974        let indices = masked.indices().to_primitive().unwrap();
975        assert_eq!(indices.as_slice::<u64>(), &[10, 13]);
976        let masked_values = masked.values().to_primitive().unwrap();
977        assert_eq!(masked_values.as_slice::<i32>(), &[200, 300]);
978    }
979
980    #[test]
981    fn test_mask_nullable_values() {
982        let patches = Patches::new(
983            10,
984            0,
985            buffer![2u64, 5, 8].into_array(),
986            PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
987        );
988
989        // Test masking removes patch at index 2
990        let mask = Mask::from_iter([
991            false, false, true, false, false, false, false, false, false, false,
992        ]);
993        let masked = patches.mask(&mask).unwrap().unwrap();
994
995        // Patches at indices 5 and 8 should remain
996        let indices = masked.indices().to_primitive().unwrap();
997        assert_eq!(indices.as_slice::<u64>(), &[5, 8]);
998
999        // Values should be the null and 300
1000        let masked_values = masked.values().to_primitive().unwrap();
1001        assert_eq!(masked_values.len(), 2);
1002        assert!(!masked_values.is_valid(0)); // the null value at index 5
1003        assert!(masked_values.is_valid(1)); // the 300 value at index 8
1004        assert_eq!(i32::try_from(&masked_values.scalar_at(1)).unwrap(), 300i32);
1005    }
1006
1007    #[test]
1008    fn test_filter_keep_all() {
1009        let patches = Patches::new(
1010            10,
1011            0,
1012            buffer![2u64, 5, 8].into_array(),
1013            buffer![100i32, 200, 300].into_array(),
1014        );
1015
1016        // Keep all indices (mask with indices 0-9)
1017        let mask = Mask::from_indices(10, (0..10).collect());
1018        let filtered = patches.filter(&mask).unwrap().unwrap();
1019
1020        let indices = filtered.indices().to_primitive().unwrap();
1021        let values = filtered.values().to_primitive().unwrap();
1022        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1023        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1024    }
1025
1026    #[test]
1027    fn test_filter_none() {
1028        let patches = Patches::new(
1029            10,
1030            0,
1031            buffer![2u64, 5, 8].into_array(),
1032            buffer![100i32, 200, 300].into_array(),
1033        );
1034
1035        // Filter out all (empty mask means keep nothing)
1036        let mask = Mask::from_indices(10, vec![]);
1037        let filtered = patches.filter(&mask).unwrap();
1038        assert!(filtered.is_none());
1039    }
1040
1041    #[test]
1042    fn test_filter_with_indices() {
1043        let patches = Patches::new(
1044            10,
1045            0,
1046            buffer![2u64, 5, 8].into_array(),
1047            buffer![100i32, 200, 300].into_array(),
1048        );
1049
1050        // Keep indices 2, 5, 9 (so patches at 2 and 5 remain)
1051        let mask = Mask::from_indices(10, vec![2, 5, 9]);
1052        let filtered = patches.filter(&mask).unwrap().unwrap();
1053
1054        let indices = filtered.indices().to_primitive().unwrap();
1055        let values = filtered.values().to_primitive().unwrap();
1056        assert_eq!(indices.as_slice::<u64>(), &[0, 1]); // Adjusted indices
1057        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
1058    }
1059
1060    #[test]
1061    fn test_slice_full_range() {
1062        let patches = Patches::new(
1063            10,
1064            0,
1065            buffer![2u64, 5, 8].into_array(),
1066            buffer![100i32, 200, 300].into_array(),
1067        );
1068
1069        let sliced = patches.slice(0..10).unwrap();
1070
1071        let indices = sliced.indices().to_primitive().unwrap();
1072        let values = sliced.values().to_primitive().unwrap();
1073        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1074        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1075    }
1076
1077    #[test]
1078    fn test_slice_partial() {
1079        let patches = Patches::new(
1080            10,
1081            0,
1082            buffer![2u64, 5, 8].into_array(),
1083            buffer![100i32, 200, 300].into_array(),
1084        );
1085
1086        // Slice from 3 to 8 (includes patch at 5)
1087        let sliced = patches.slice(3..8).unwrap();
1088
1089        let indices = sliced.indices().to_primitive().unwrap();
1090        let values = sliced.values().to_primitive().unwrap();
1091        assert_eq!(indices.as_slice::<u64>(), &[5]); // Index stays the same
1092        assert_eq!(values.as_slice::<i32>(), &[200]);
1093        assert_eq!(sliced.array_len(), 5); // 8 - 3 = 5
1094        assert_eq!(sliced.offset(), 3); // New offset
1095    }
1096
1097    #[test]
1098    fn test_slice_no_patches() {
1099        let patches = Patches::new(
1100            10,
1101            0,
1102            buffer![2u64, 5, 8].into_array(),
1103            buffer![100i32, 200, 300].into_array(),
1104        );
1105
1106        // Slice from 6 to 7 (no patches in this range)
1107        let sliced = patches.slice(6..7);
1108        assert!(sliced.is_none());
1109    }
1110
1111    #[test]
1112    fn test_slice_with_offset() {
1113        let patches = Patches::new(
1114            10,
1115            5,                                  // offset
1116            buffer![7u64, 10, 13].into_array(), // actual indices are 2, 5, 8
1117            buffer![100i32, 200, 300].into_array(),
1118        );
1119
1120        // Slice from 3 to 8 (includes patch at actual index 5)
1121        let sliced = patches.slice(3..8).unwrap();
1122
1123        let indices = sliced.indices().to_primitive().unwrap();
1124        let values = sliced.values().to_primitive().unwrap();
1125        assert_eq!(indices.as_slice::<u64>(), &[10]); // Index stays the same (offset + 5 = 10)
1126        assert_eq!(values.as_slice::<i32>(), &[200]);
1127        assert_eq!(sliced.offset(), 8); // New offset = 5 + 3
1128    }
1129
1130    #[test]
1131    fn test_patch_values() {
1132        let patches = Patches::new(
1133            10,
1134            0,
1135            buffer![2u64, 5, 8].into_array(),
1136            buffer![100i32, 200, 300].into_array(),
1137        );
1138
1139        let values = patches.values().to_primitive().unwrap();
1140        assert_eq!(i32::try_from(&values.scalar_at(0)).unwrap(), 100i32);
1141        assert_eq!(i32::try_from(&values.scalar_at(1)).unwrap(), 200i32);
1142        assert_eq!(i32::try_from(&values.scalar_at(2)).unwrap(), 300i32);
1143    }
1144
1145    #[test]
1146    fn test_indices_range() {
1147        let patches = Patches::new(
1148            10,
1149            0,
1150            buffer![2u64, 5, 8].into_array(),
1151            buffer![100i32, 200, 300].into_array(),
1152        );
1153
1154        assert_eq!(patches.min_index(), 2);
1155        assert_eq!(patches.max_index(), 8);
1156    }
1157
1158    #[test]
1159    fn test_search_index() {
1160        let patches = Patches::new(
1161            10,
1162            0,
1163            buffer![2u64, 5, 8].into_array(),
1164            buffer![100i32, 200, 300].into_array(),
1165        );
1166
1167        // Search for exact indices
1168        assert_eq!(patches.search_index(2), SearchResult::Found(0));
1169        assert_eq!(patches.search_index(5), SearchResult::Found(1));
1170        assert_eq!(patches.search_index(8), SearchResult::Found(2));
1171
1172        // Search for non-patch indices
1173        assert_eq!(patches.search_index(0), SearchResult::NotFound(0));
1174        assert_eq!(patches.search_index(3), SearchResult::NotFound(1));
1175        assert_eq!(patches.search_index(6), SearchResult::NotFound(2));
1176        assert_eq!(patches.search_index(9), SearchResult::NotFound(3));
1177    }
1178
1179    #[test]
1180    fn test_mask_boundary_patches() {
1181        // Test masking patches at array boundaries
1182        let patches = Patches::new(
1183            10,
1184            0,
1185            buffer![0u64, 9].into_array(),
1186            buffer![100i32, 200].into_array(),
1187        );
1188
1189        let mask = Mask::from_iter([
1190            true, false, false, false, false, false, false, false, false, false,
1191        ]);
1192        let masked = patches.mask(&mask).unwrap();
1193        assert!(masked.is_some());
1194        let masked = masked.unwrap();
1195        let indices = masked.indices().to_primitive().unwrap();
1196        assert_eq!(indices.as_slice::<u64>(), &[9]);
1197        let values = masked.values().to_primitive().unwrap();
1198        assert_eq!(values.as_slice::<i32>(), &[200]);
1199    }
1200
1201    #[test]
1202    fn test_mask_all_patches_removed() {
1203        // Test when all patches are masked out
1204        let patches = Patches::new(
1205            10,
1206            0,
1207            buffer![2u64, 5, 8].into_array(),
1208            buffer![100i32, 200, 300].into_array(),
1209        );
1210
1211        // Mask that removes all patches
1212        let mask = Mask::from_iter([
1213            false, false, true, false, false, true, false, false, true, false,
1214        ]);
1215        let masked = patches.mask(&mask).unwrap();
1216        assert!(masked.is_none());
1217    }
1218
1219    #[test]
1220    fn test_mask_no_patches_removed() {
1221        // Test when no patches are masked
1222        let patches = Patches::new(
1223            10,
1224            0,
1225            buffer![2u64, 5, 8].into_array(),
1226            buffer![100i32, 200, 300].into_array(),
1227        );
1228
1229        // Mask that doesn't affect any patches
1230        let mask = Mask::from_iter([
1231            true, false, false, true, false, false, true, false, false, true,
1232        ]);
1233        let masked = patches.mask(&mask).unwrap().unwrap();
1234
1235        let indices = masked.indices().to_primitive().unwrap();
1236        assert_eq!(indices.as_slice::<u64>(), &[2, 5, 8]);
1237        let values = masked.values().to_primitive().unwrap();
1238        assert_eq!(values.as_slice::<i32>(), &[100, 200, 300]);
1239    }
1240
1241    #[test]
1242    fn test_mask_single_patch() {
1243        // Test with a single patch
1244        let patches = Patches::new(
1245            5,
1246            0,
1247            buffer![2u64].into_array(),
1248            buffer![42i32].into_array(),
1249        );
1250
1251        // Mask that removes the single patch
1252        let mask = Mask::from_iter([false, false, true, false, false]);
1253        let masked = patches.mask(&mask).unwrap();
1254        assert!(masked.is_none());
1255
1256        // Mask that keeps the single patch
1257        let mask = Mask::from_iter([true, false, false, true, false]);
1258        let masked = patches.mask(&mask).unwrap().unwrap();
1259        let indices = masked.indices().to_primitive().unwrap();
1260        assert_eq!(indices.as_slice::<u64>(), &[2]);
1261    }
1262
1263    #[test]
1264    fn test_mask_contiguous_patches() {
1265        // Test with contiguous patches
1266        let patches = Patches::new(
1267            10,
1268            0,
1269            buffer![3u64, 4, 5, 6].into_array(),
1270            buffer![100i32, 200, 300, 400].into_array(),
1271        );
1272
1273        // Mask that removes middle patches
1274        let mask = Mask::from_iter([
1275            false, false, false, false, true, true, false, false, false, false,
1276        ]);
1277        let masked = patches.mask(&mask).unwrap().unwrap();
1278
1279        let indices = masked.indices().to_primitive().unwrap();
1280        assert_eq!(indices.as_slice::<u64>(), &[3, 6]);
1281        let values = masked.values().to_primitive().unwrap();
1282        assert_eq!(values.as_slice::<i32>(), &[100, 400]);
1283    }
1284
1285    #[test]
1286    fn test_mask_with_large_offset() {
1287        // Test with a large offset that shifts all indices
1288        let patches = Patches::new(
1289            20,
1290            15,
1291            buffer![16u64, 17, 19].into_array(), // actual indices are 1, 2, 4
1292            buffer![100i32, 200, 300].into_array(),
1293        );
1294
1295        // Mask that removes the patch at actual index 2
1296        let mask = Mask::from_iter([
1297            false, false, true, false, false, false, false, false, false, false, false, false,
1298            false, false, false, false, false, false, false, false,
1299        ]);
1300        let masked = patches.mask(&mask).unwrap().unwrap();
1301
1302        let indices = masked.indices().to_primitive().unwrap();
1303        assert_eq!(indices.as_slice::<u64>(), &[16, 19]);
1304        let values = masked.values().to_primitive().unwrap();
1305        assert_eq!(values.as_slice::<i32>(), &[100, 300]);
1306    }
1307
1308    #[test]
1309    #[should_panic(expected = "Filter mask length 5 does not match array length 10")]
1310    fn test_mask_wrong_length() {
1311        let patches = Patches::new(
1312            10,
1313            0,
1314            buffer![2u64, 5, 8].into_array(),
1315            buffer![100i32, 200, 300].into_array(),
1316        );
1317
1318        // Mask with wrong length
1319        let mask = Mask::from_iter([false, false, true, false, false]);
1320        let _ = patches.mask(&mask).unwrap();
1321    }
1322}