vortex_array/
patches.rs

1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::ToPrimitive;
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{
18    SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, search_sorted_usize,
19    search_sorted_usize_many, slice, take, try_cast,
20};
21use crate::variants::PrimitiveArrayTrait;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical};
23
/// Serialized description of a set of [`Patches`]: how many patches there are, the offset that
/// was applied to their stored indices, and the primitive type the indices are stored as.
///
/// The field order is part of the `#[repr(C)]` / rkyv layout — do not reorder.
#[derive(
    Copy,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::bytecheck::CheckBytes,
)]
#[bytecheck(crate = rkyv::bytecheck)]
#[repr(C)]
pub struct PatchesMetadata {
    /// Number of patches (length of both the indices and the values arrays).
    len: usize,
    /// Offset subtracted from each stored patch index to get its position in the patched array.
    offset: usize,
    /// Primitive type of the stored patch indices; expected to be an unsigned integer
    /// (checked by `indices_dtype`).
    indices_ptype: PType,
}
42
43impl PatchesMetadata {
44    pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
45        Self {
46            len,
47            offset,
48            indices_ptype,
49        }
50    }
51
52    #[inline]
53    pub fn len(&self) -> usize {
54        self.len
55    }
56
57    #[inline]
58    pub fn is_empty(&self) -> bool {
59        self.len == 0
60    }
61
62    #[inline]
63    pub fn offset(&self) -> usize {
64        self.offset
65    }
66
67    #[inline]
68    pub fn indices_dtype(&self) -> DType {
69        assert!(
70            self.indices_ptype.is_unsigned_int(),
71            "Patch indices must be unsigned integers"
72        );
73        DType::Primitive(self.indices_ptype, NonNullable)
74    }
75}
76
/// A helper for working with patched arrays.
///
/// Holds a sorted array of patch indices and an equally long array of values that override the
/// base array at those positions.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array these patches apply to.
    array_len: usize,
    /// Offset subtracted from each stored index to get its position in the patched array
    /// (slicing moves this offset instead of rewriting the indices).
    offset: usize,
    /// Sorted, unsigned-integer patch indices (still shifted by `offset`).
    indices: ArrayRef,
    /// Patch values, aligned element-for-element with `indices`.
    values: ArrayRef,
}
85
86impl Patches {
87    pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
88        assert_eq!(
89            indices.len(),
90            values.len(),
91            "Patch indices and values must have the same length"
92        );
93        assert!(
94            indices.dtype().is_unsigned_int(),
95            "Patch indices must be unsigned integers"
96        );
97        assert!(
98            indices.len() <= array_len,
99            "Patch indices must be shorter than the array length"
100        );
101        assert!(!indices.is_empty(), "Patch indices must not be empty");
102        let max = usize::try_from(
103            &scalar_at(&indices, indices.len() - 1).vortex_expect("indices are not empty"),
104        )
105        .vortex_expect("indices must be a number");
106        assert!(
107            max - offset < array_len,
108            "Patch indices {:?}, offset {} are longer than the array length {}",
109            max,
110            offset,
111            array_len
112        );
113        Self::new_unchecked(array_len, offset, indices, values)
114    }
115
116    /// Construct new patches without validating any of the arguments
117    ///
118    /// # Safety
119    ///
120    /// Users have to assert that
121    /// * Indices and values have the same length
122    /// * Indices is an unsigned integer type
123    /// * Indices must be sorted
124    /// * Last value in indices is smaller than array_len
125    pub fn new_unchecked(
126        array_len: usize,
127        offset: usize,
128        indices: ArrayRef,
129        values: ArrayRef,
130    ) -> Self {
131        Self {
132            array_len,
133            offset,
134            indices,
135            values,
136        }
137    }
138
139    // TODO(ngates): remove this...
140    pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
141        (self.array_len, self.offset, self.indices, self.values)
142    }
143
144    pub fn array_len(&self) -> usize {
145        self.array_len
146    }
147
148    pub fn num_patches(&self) -> usize {
149        self.indices.len()
150    }
151
152    pub fn dtype(&self) -> &DType {
153        self.values.dtype()
154    }
155
156    pub fn indices(&self) -> &ArrayRef {
157        &self.indices
158    }
159
160    pub fn into_indices(self) -> ArrayRef {
161        self.indices
162    }
163
164    pub fn values(&self) -> &ArrayRef {
165        &self.values
166    }
167
168    pub fn into_values(self) -> ArrayRef {
169        self.values
170    }
171
172    pub fn offset(&self) -> usize {
173        self.offset
174    }
175
176    pub fn indices_ptype(&self) -> PType {
177        PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
178    }
179
180    pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
181        if self.indices.len() > len {
182            vortex_bail!(
183                "Patch indices {} are longer than the array length {}",
184                self.indices.len(),
185                len
186            );
187        }
188        if self.values.dtype() != dtype {
189            vortex_bail!(
190                "Patch values dtype {} does not match array dtype {}",
191                self.values.dtype(),
192                dtype
193            );
194        }
195        Ok(PatchesMetadata {
196            len: self.indices.len(),
197            offset: self.offset,
198            indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices"),
199        })
200    }
201
202    pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
203        Ok(Self::new_unchecked(
204            self.array_len,
205            self.offset,
206            self.indices,
207            try_cast(&self.values, values_dtype)?,
208        ))
209    }
210
211    /// Get the patched value at a given index if it exists.
212    pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
213        if let Some(patch_idx) = self.search_index(index)?.to_found() {
214            scalar_at(self.values(), patch_idx).map(Some)
215        } else {
216            Ok(None)
217        }
218    }
219
220    /// Return the insertion point of [index] in the [Self::indices].
221    fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
222        search_sorted_usize(&self.indices, index + self.offset, SearchSortedSide::Left)
223    }
224
225    /// Return the search_sorted result for the given target re-mapped into the original indices.
226    pub fn search_sorted<T: Into<Scalar>>(
227        &self,
228        target: T,
229        side: SearchSortedSide,
230    ) -> VortexResult<SearchResult> {
231        search_sorted(self.values(), target.into(), side).and_then(|sr| {
232            let sidx = sr.to_offsets_index(self.indices().len());
233            let index = usize::try_from(&scalar_at(self.indices(), sidx)?)? - self.offset;
234            Ok(match sr {
235                // If we reached the end of patched values when searching then the result is one after the last patch index
236                SearchResult::Found(i) => SearchResult::Found(if i == self.indices().len() {
237                    index + 1
238                } else {
239                    index
240                }),
241                // If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
242                SearchResult::NotFound(i) => {
243                    SearchResult::NotFound(if i == 0 { index } else { index + 1 })
244                }
245            })
246        })
247    }
248
249    /// Returns the minimum patch index
250    pub fn min_index(&self) -> VortexResult<usize> {
251        Ok(usize::try_from(&scalar_at(self.indices(), 0)?)? - self.offset)
252    }
253
254    /// Returns the maximum patch index
255    pub fn max_index(&self) -> VortexResult<usize> {
256        Ok(usize::try_from(&scalar_at(self.indices(), self.indices().len() - 1)?)? - self.offset)
257    }
258
259    /// Filter the patches by a mask, resulting in new patches for the filtered array.
260    pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
261        match mask.indices() {
262            AllOr::All => Ok(Some(self.clone())),
263            AllOr::None => Ok(None),
264            AllOr::Some(mask_indices) => {
265                let flat_indices = self.indices().to_primitive()?;
266                match_each_integer_ptype!(flat_indices.ptype(), |$I| {
267                    filter_patches_with_mask(
268                        flat_indices.as_slice::<$I>(),
269                        self.offset(),
270                        self.values(),
271                        mask_indices,
272                    )
273                })
274            }
275        }
276    }
277
278    /// Slice the patches by a range of the patched array.
279    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
280        let patch_start = self.search_index(start)?.to_index();
281        let patch_stop = self.search_index(stop)?.to_index();
282
283        if patch_start == patch_stop {
284            return Ok(None);
285        }
286
287        // Slice out the values and indices
288        let values = slice(self.values(), patch_start, patch_stop)?;
289        let indices = slice(self.indices(), patch_start, patch_stop)?;
290
291        Ok(Some(Self::new(
292            stop - start,
293            start + self.offset(),
294            indices,
295            values,
296        )))
297    }
298
299    // https://docs.google.com/spreadsheets/d/1D9vBZ1QJ6mwcIvV5wIL0hjGgVchcEnAyhvitqWu2ugU
300    const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
301
302    fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
303        (self.num_patches() as f64 / take_indices.len() as f64)
304            < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
305    }
306
307    /// Take the indices from the patches.
308    pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
309        if take_indices.is_empty() {
310            return Ok(None);
311        }
312        let take_indices = take_indices.to_primitive()?;
313        if self.is_map_faster_than_search(&take_indices) {
314            self.take_map(take_indices)
315        } else {
316            self.take_search(take_indices)
317        }
318    }
319
320    pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
321        let new_length = take_indices.len();
322
323        let Some((new_indices, values_indices)) = match_each_integer_ptype!(take_indices.ptype(), |$I| {
324            take_search::<$I>(self.indices(), take_indices, self.offset())?
325        }) else {
326            return Ok(None);
327        };
328
329        Ok(Some(Self::new(
330            new_length,
331            0,
332            new_indices,
333            take(self.values(), &values_indices)?,
334        )))
335    }
336
337    pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
338        let indices = self.indices.to_primitive()?;
339        let new_length = take_indices.len();
340
341        let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
342            match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
343                take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
344            })
345        }) else {
346            return Ok(None);
347        };
348
349        Ok(Some(Patches::new(
350            new_length,
351            0,
352            new_sparse_indices,
353            take(self.values(), &value_indices)?,
354        )))
355    }
356
357    pub fn map_values<F>(self, f: F) -> VortexResult<Self>
358    where
359        F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
360    {
361        let values = f(self.values)?;
362        if self.indices.len() != values.len() {
363            vortex_bail!(
364                "map_values must preserve length: expected {} received {}",
365                self.indices.len(),
366                values.len()
367            )
368        }
369        Ok(Self::new(self.array_len, self.offset, self.indices, values))
370    }
371}
372
373fn take_search<T: NativePType + TryFrom<usize>>(
374    indices: &dyn Array,
375    take_indices: PrimitiveArray,
376    indices_offset: usize,
377) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
378where
379    usize: TryFrom<T>,
380    VortexError: From<<usize as TryFrom<T>>::Error>,
381{
382    let take_indices_validity = take_indices.validity();
383    let take_indices = take_indices
384        .as_slice::<T>()
385        .iter()
386        .copied()
387        .map(usize::try_from)
388        .map_ok(|idx| idx + indices_offset)
389        .collect::<Result<Vec<_>, _>>()?;
390
391    let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) =
392        search_sorted_usize_many(indices, &take_indices, SearchSortedSide::Left)?
393            .iter()
394            .enumerate()
395            .filter_map(|(idx_in_take, search_result)| {
396                search_result
397                    .to_found()
398                    .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
399            })
400            .unzip();
401
402    if new_indices.is_empty() {
403        return Ok(None);
404    }
405
406    let new_indices = new_indices.into_array();
407    let values_validity = take_indices_validity.take(&new_indices)?;
408    Ok(Some((
409        new_indices,
410        PrimitiveArray::new(values_indices, values_validity).into_array(),
411    )))
412}
413
414fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
415    indices: &[I],
416    take_indices: PrimitiveArray,
417    indices_offset: usize,
418    min_index: usize,
419    max_index: usize,
420) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
421where
422    usize: TryFrom<T>,
423    VortexError: From<<I as TryFrom<usize>>::Error>,
424{
425    let take_indices_validity = take_indices.validity();
426    let take_indices = take_indices.as_slice::<T>();
427    let offset_i = I::try_from(indices_offset)?;
428
429    let sparse_index_to_value_index: HashMap<I, usize> = indices
430        .iter()
431        .copied()
432        .map(|idx| idx - offset_i)
433        .enumerate()
434        .map(|(value_index, sparse_index)| (sparse_index, value_index))
435        .collect();
436    let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
437        .iter()
438        .copied()
439        .map(usize::try_from)
440        .process_results(|iter| {
441            iter.enumerate()
442                .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
443                .filter_map(|(new_sparse_index, take_sparse_index)| {
444                    sparse_index_to_value_index
445                        .get(
446                            &I::try_from(take_sparse_index)
447                                .vortex_expect("take_sparse_index is between min and max index"),
448                        )
449                        .map(|value_index| (new_sparse_index as u64, *value_index as u64))
450                })
451                .unzip()
452        })
453        .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
454
455    if new_sparse_indices.is_empty() {
456        return Ok(None);
457    }
458
459    let new_sparse_indices = new_sparse_indices.into_array();
460    let values_validity = take_indices_validity.take(&new_sparse_indices)?;
461    Ok(Some((
462        new_sparse_indices,
463        PrimitiveArray::new(value_indices, values_validity).into_array(),
464    )))
465}
466
467/// Filter patches with the provided mask (in flattened space).
468///
469/// The filter mask may contain indices that are non-patched. The return value of this function
470/// is a new set of `Patches` with the indices relative to the provided `mask` rank, and the
471/// patch values.
472fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
473    patch_indices: &[T],
474    offset: usize,
475    patch_values: &dyn Array,
476    mask_indices: &[usize],
477) -> VortexResult<Option<Patches>> {
478    let true_count = mask_indices.len();
479    let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
480    let mut new_mask_indices = Vec::with_capacity(true_count);
481
482    // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that
483    // the patches are relatively sparse compared to the overall mask, and so many indices in the
484    // mask will end up being skipped.
485    const STRIDE: usize = 4;
486
487    let mut mask_idx = 0usize;
488    let mut true_idx = 0usize;
489
490    while mask_idx < patch_indices.len() && true_idx < true_count {
491        // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices`
492        //  and `mask_indices`. We assume that Patches are sparse relative to the global space of
493        //  the mask (which covers both patch and non-patch values of the parent array), and so to
494        //  quickly jump through regions with no overlap, we attempt to move our pointers by STRIDE
495        //  elements on each iteration. If we cannot rule out overlap due to min/max values, we
496        //  fallback to performing a two-way iterator merge.
497        if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
498            // Load a vector of each into our registers.
499            let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
500            let left_max = patch_indices[mask_idx + STRIDE]
501                .to_usize()
502                .vortex_expect("left_max")
503                - offset;
504            let right_min = mask_indices[true_idx];
505            let right_max = mask_indices[true_idx + STRIDE];
506
507            if left_min > right_max {
508                // Advance right side
509                true_idx += STRIDE;
510                continue;
511            } else if right_min > left_max {
512                mask_idx += STRIDE;
513                continue;
514            } else {
515                // Fallthrough to direct comparison path.
516            }
517        }
518
519        // Two-way sorted iterator merge:
520
521        let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
522        let right = mask_indices[true_idx];
523
524        match left.cmp(&right) {
525            Ordering::Less => {
526                mask_idx += 1;
527            }
528            Ordering::Greater => {
529                true_idx += 1;
530            }
531            Ordering::Equal => {
532                // Save the mask index as well as the positional index.
533                new_mask_indices.push(mask_idx);
534                new_patch_indices.push(true_idx as u64);
535
536                mask_idx += 1;
537                true_idx += 1;
538            }
539        }
540    }
541
542    if new_mask_indices.is_empty() {
543        return Ok(None);
544    }
545
546    let new_patch_indices = new_patch_indices.into_array();
547    let new_patch_values = filter(
548        patch_values,
549        &Mask::from_indices(patch_values.len(), new_mask_indices),
550    )?;
551
552    Ok(Some(Patches::new(
553        true_count,
554        0,
555        new_patch_indices,
556        new_patch_values,
557    )))
558}
559
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::compute::{SearchResult, SearchSortedSide};
    use crate::patches::Patches;
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Exercises the stride path that skips whole windows of mask indices lying before the
    /// next patch.
    #[test]
    fn filter_skips_mask_windows() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 20, 30, 40, 50, 60, 70, 80].into_array(),
            buffer![1i32, 2, 3, 4, 5, 6, 7, 8].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(
                100,
                vec![0, 1, 2, 3, 4, 5, 50, 51, 52, 53, 54, 55, 90],
            ))
            .unwrap()
            .unwrap();

        assert_eq!(filtered.array_len(), 13);
        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[6]);
        assert_eq!(values.as_slice::<i32>(), &[5]);
    }

    /// Exercises the stride path that skips whole windows of patches lying before the next
    /// mask index.
    #[test]
    fn filter_skips_patch_windows() {
        let patches = Patches::new(
            100,
            0,
            buffer![0u32, 1, 2, 3, 4, 5, 90].into_array(),
            buffer![10i32, 11, 12, 13, 14, 15, 16].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![50, 60, 70, 80, 85, 90, 95]))
            .unwrap()
            .unwrap();

        assert_eq!(filtered.array_len(), 7);
        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[5]);
        assert_eq!(values.as_slice::<i32>(), &[16]);
    }

    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    /// Slicing keeps the stored indices and moves the offset; filtering must undo that offset.
    #[rstest]
    fn filter_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        let filtered = sliced
            .filter(&Mask::from_indices(13, vec![2, 5, 8]))
            .unwrap()
            .unwrap();

        assert_eq!(filtered.array_len(), 3);
        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 2]);
        assert_eq!(values.as_slice::<i32>(), &[44, 55]);
    }

    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let patches = Patches::new(
            2,
            0,
            buffer![0u64].into_array(),
            PrimitiveArray::new(buffer![0u8], Validity::AllValid).into_array(),
        );

        assert_eq!(
            patches.search_sorted(0, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(1)
        );
        assert_eq!(
            patches.search_sorted(1, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }
}