1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::{NumCast, ToPrimitive};
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::{PValue, Scalar};
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{cast, filter, take};
18use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
19use crate::vtable::ValidityHelper;
20use crate::{Array, ArrayRef, IntoArray, ToCanonical};
21
22#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
23pub struct PatchesMetadata {
24 #[prost(uint64, tag = "1")]
25 len: u64,
26 #[prost(uint64, tag = "2")]
27 offset: u64,
28 #[prost(enumeration = "PType", tag = "3")]
29 indices_ptype: i32,
30}
31
32impl PatchesMetadata {
33 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
34 Self {
35 len: len as u64,
36 offset: offset as u64,
37 indices_ptype: indices_ptype as i32,
38 }
39 }
40
41 #[inline]
42 pub fn len(&self) -> usize {
43 usize::try_from(self.len).vortex_expect("len is a valid usize")
44 }
45
46 #[inline]
47 pub fn is_empty(&self) -> bool {
48 self.len == 0
49 }
50
51 #[inline]
52 pub fn offset(&self) -> usize {
53 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
54 }
55
56 #[inline]
57 pub fn indices_dtype(&self) -> DType {
58 assert!(
59 self.indices_ptype().is_unsigned_int(),
60 "Patch indices must be unsigned integers"
61 );
62 DType::Primitive(self.indices_ptype(), NonNullable)
63 }
64}
65
66#[derive(Debug, Clone)]
68pub struct Patches {
69 array_len: usize,
70 offset: usize,
71 indices: ArrayRef,
72 values: ArrayRef,
73}
74
75impl Patches {
76 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
77 assert_eq!(
78 indices.len(),
79 values.len(),
80 "Patch indices and values must have the same length"
81 );
82 assert!(
83 indices.dtype().is_unsigned_int(),
84 "Patch indices must be unsigned integers"
85 );
86 assert!(
87 indices.len() <= array_len,
88 "Patch indices must be shorter than the array length"
89 );
90 assert!(!indices.is_empty(), "Patch indices must not be empty");
91 let max = usize::try_from(
92 &indices
93 .scalar_at(indices.len() - 1)
94 .vortex_expect("indices are not empty"),
95 )
96 .vortex_expect("indices must be a number");
97 assert!(
98 max - offset < array_len,
99 "Patch indices {:?}, offset {} are longer than the array length {}",
100 max,
101 offset,
102 array_len
103 );
104 Self::new_unchecked(array_len, offset, indices, values)
105 }
106
107 pub fn new_unchecked(
117 array_len: usize,
118 offset: usize,
119 indices: ArrayRef,
120 values: ArrayRef,
121 ) -> Self {
122 Self {
123 array_len,
124 offset,
125 indices,
126 values,
127 }
128 }
129
130 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
132 (self.array_len, self.offset, self.indices, self.values)
133 }
134
135 pub fn array_len(&self) -> usize {
136 self.array_len
137 }
138
139 pub fn num_patches(&self) -> usize {
140 self.indices.len()
141 }
142
143 pub fn dtype(&self) -> &DType {
144 self.values.dtype()
145 }
146
147 pub fn indices(&self) -> &ArrayRef {
148 &self.indices
149 }
150
151 pub fn into_indices(self) -> ArrayRef {
152 self.indices
153 }
154
155 pub fn indices_mut(&mut self) -> &mut ArrayRef {
156 &mut self.indices
157 }
158
159 pub fn values(&self) -> &ArrayRef {
160 &self.values
161 }
162
163 pub fn into_values(self) -> ArrayRef {
164 self.values
165 }
166
167 pub fn values_mut(&mut self) -> &mut ArrayRef {
168 &mut self.values
169 }
170
171 pub fn offset(&self) -> usize {
172 self.offset
173 }
174
175 pub fn indices_ptype(&self) -> PType {
176 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
177 }
178
179 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
180 if self.indices.len() > len {
181 vortex_bail!(
182 "Patch indices {} are longer than the array length {}",
183 self.indices.len(),
184 len
185 );
186 }
187 if self.values.dtype() != dtype {
188 vortex_bail!(
189 "Patch values dtype {} does not match array dtype {}",
190 self.values.dtype(),
191 dtype
192 );
193 }
194 Ok(PatchesMetadata {
195 len: self.indices.len() as u64,
196 offset: self.offset as u64,
197 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
198 as i32,
199 })
200 }
201
202 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
203 Ok(Self::new_unchecked(
204 self.array_len,
205 self.offset,
206 self.indices,
207 cast(&self.values, values_dtype)?,
208 ))
209 }
210
211 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
213 if let Some(patch_idx) = self.search_index(index)?.to_found() {
214 self.values().scalar_at(patch_idx).map(Some)
215 } else {
216 Ok(None)
217 }
218 }
219
220 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
222 Ok(self.indices.as_primitive_typed().search_sorted(
223 &PValue::U64((index + self.offset) as u64),
224 SearchSortedSide::Left,
225 ))
226 }
227
228 pub fn search_sorted<T: Into<Scalar>>(
230 &self,
231 target: T,
232 side: SearchSortedSide,
233 ) -> VortexResult<SearchResult> {
234 let target = target.into();
235
236 let sr = if self.values().dtype().is_primitive() {
237 self.values()
238 .as_primitive_typed()
239 .search_sorted(&target.as_primitive().pvalue(), side)
240 } else {
241 self.values().search_sorted(&target, side)
242 };
243
244 let index_idx = sr.to_offsets_index(self.indices().len(), side);
245 let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
246 Ok(match sr {
247 SearchResult::Found(i) => SearchResult::Found(
249 if i == self.indices().len() || side == SearchSortedSide::Right {
250 index + 1
251 } else {
252 index
253 },
254 ),
255 SearchResult::NotFound(i) => {
257 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
258 }
259 })
260 }
261
262 pub fn min_index(&self) -> VortexResult<usize> {
264 Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
265 }
266
267 pub fn max_index(&self) -> VortexResult<usize> {
269 Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
270 }
271
272 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
274 match mask.indices() {
275 AllOr::All => Ok(Some(self.clone())),
276 AllOr::None => Ok(None),
277 AllOr::Some(mask_indices) => {
278 let flat_indices = self.indices().to_primitive()?;
279 match_each_integer_ptype!(flat_indices.ptype(), |$I| {
280 filter_patches_with_mask(
281 flat_indices.as_slice::<$I>(),
282 self.offset(),
283 self.values(),
284 mask_indices,
285 )
286 })
287 }
288 }
289 }
290
291 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
293 let patch_start = self.search_index(start)?.to_index();
294 let patch_stop = self.search_index(stop)?.to_index();
295
296 if patch_start == patch_stop {
297 return Ok(None);
298 }
299
300 let values = self.values().slice(patch_start, patch_stop)?;
302 let indices = self.indices().slice(patch_start, patch_stop)?;
303
304 Ok(Some(Self::new(
305 stop - start,
306 start + self.offset(),
307 indices,
308 values,
309 )))
310 }
311
312 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
314
315 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
316 (self.num_patches() as f64 / take_indices.len() as f64)
317 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
318 }
319
320 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
322 if take_indices.is_empty() {
323 return Ok(None);
324 }
325 let take_indices = take_indices.to_primitive()?;
326 if self.is_map_faster_than_search(&take_indices) {
327 self.take_map(take_indices)
328 } else {
329 self.take_search(take_indices)
330 }
331 }
332
333 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
334 let indices = self.indices.to_primitive()?;
335 let new_length = take_indices.len();
336
337 let Some((new_indices, values_indices)) = match_each_integer_ptype!(indices.ptype(), |$INDICES| {
338 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
339 take_search::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset())?
340 })
341 }) else {
342 return Ok(None);
343 };
344
345 Ok(Some(Self::new(
346 new_length,
347 0,
348 new_indices,
349 take(self.values(), &values_indices)?,
350 )))
351 }
352
353 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
354 let indices = self.indices.to_primitive()?;
355 let new_length = take_indices.len();
356
357 let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
358 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
359 take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
360 })
361 }) else {
362 return Ok(None);
363 };
364
365 Ok(Some(Patches::new(
366 new_length,
367 0,
368 new_sparse_indices,
369 take(self.values(), &value_indices)?,
370 )))
371 }
372
373 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
374 where
375 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
376 {
377 let values = f(self.values)?;
378 if self.indices.len() != values.len() {
379 vortex_bail!(
380 "map_values must preserve length: expected {} received {}",
381 self.indices.len(),
382 values.len()
383 )
384 }
385 Ok(Self::new(self.array_len, self.offset, self.indices, values))
386 }
387}
388
389fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
390 indices: &[I],
391 take_indices: PrimitiveArray,
392 indices_offset: usize,
393) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
394where
395 usize: TryFrom<T>,
396 VortexError: From<<usize as TryFrom<T>>::Error>,
397{
398 let take_indices_validity = take_indices.validity();
399 let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
400
401 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
402 .as_slice::<T>()
403 .iter()
404 .map(|v| {
405 match I::from(*v) {
406 None => {
407 SearchResult::NotFound(indices.len())
409 }
410 Some(v) => indices.search_sorted(&(v + indices_offset), SearchSortedSide::Left),
411 }
412 })
413 .enumerate()
414 .filter_map(|(idx_in_take, search_result)| {
415 search_result
416 .to_found()
417 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
418 })
419 .unzip();
420
421 if new_indices.is_empty() {
422 return Ok(None);
423 }
424
425 let new_indices = new_indices.into_array();
426 let values_validity = take_indices_validity.take(&new_indices)?;
427 Ok(Some((
428 new_indices,
429 PrimitiveArray::new(values_indices, values_validity).into_array(),
430 )))
431}
432
433fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
434 indices: &[I],
435 take_indices: PrimitiveArray,
436 indices_offset: usize,
437 min_index: usize,
438 max_index: usize,
439) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
440where
441 usize: TryFrom<T>,
442 VortexError: From<<I as TryFrom<usize>>::Error>,
443{
444 let take_indices_validity = take_indices.validity();
445 let take_indices = take_indices.as_slice::<T>();
446 let offset_i = I::try_from(indices_offset)?;
447
448 let sparse_index_to_value_index: HashMap<I, usize> = indices
449 .iter()
450 .copied()
451 .map(|idx| idx - offset_i)
452 .enumerate()
453 .map(|(value_index, sparse_index)| (sparse_index, value_index))
454 .collect();
455 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
456 .iter()
457 .copied()
458 .map(usize::try_from)
459 .process_results(|iter| {
460 iter.enumerate()
461 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
462 .filter_map(|(new_sparse_index, take_sparse_index)| {
463 sparse_index_to_value_index
464 .get(
465 &I::try_from(take_sparse_index)
466 .vortex_expect("take_sparse_index is between min and max index"),
467 )
468 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
469 })
470 .unzip()
471 })
472 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
473
474 if new_sparse_indices.is_empty() {
475 return Ok(None);
476 }
477
478 let new_sparse_indices = new_sparse_indices.into_array();
479 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
480 Ok(Some((
481 new_sparse_indices,
482 PrimitiveArray::new(value_indices, values_validity).into_array(),
483 )))
484}
485
486fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
492 patch_indices: &[T],
493 offset: usize,
494 patch_values: &dyn Array,
495 mask_indices: &[usize],
496) -> VortexResult<Option<Patches>> {
497 let true_count = mask_indices.len();
498 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
499 let mut new_mask_indices = Vec::with_capacity(true_count);
500
501 const STRIDE: usize = 4;
505
506 let mut mask_idx = 0usize;
507 let mut true_idx = 0usize;
508
509 while mask_idx < patch_indices.len() && true_idx < true_count {
510 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
517 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
519 let left_max = patch_indices[mask_idx + STRIDE]
520 .to_usize()
521 .vortex_expect("left_max")
522 - offset;
523 let right_min = mask_indices[true_idx];
524 let right_max = mask_indices[true_idx + STRIDE];
525
526 if left_min > right_max {
527 true_idx += STRIDE;
529 continue;
530 } else if right_min > left_max {
531 mask_idx += STRIDE;
532 continue;
533 } else {
534 }
536 }
537
538 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
541 let right = mask_indices[true_idx];
542
543 match left.cmp(&right) {
544 Ordering::Less => {
545 mask_idx += 1;
546 }
547 Ordering::Greater => {
548 true_idx += 1;
549 }
550 Ordering::Equal => {
551 new_mask_indices.push(mask_idx);
553 new_patch_indices.push(true_idx as u64);
554
555 mask_idx += 1;
556 true_idx += 1;
557 }
558 }
559 }
560
561 if new_mask_indices.is_empty() {
562 return Ok(None);
563 }
564
565 let new_patch_indices = new_patch_indices.into_array();
566 let new_patch_values = filter(
567 patch_values,
568 &Mask::from_indices(patch_values.len(), new_mask_indices),
569 )?;
570
571 Ok(Some(Patches::new(
572 true_count,
573 0,
574 new_patch_indices,
575 new_patch_values,
576 )))
577}
578
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::arrays::PrimitiveArray;
    use crate::patches::Patches;
    use crate::search_sorted::{SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps only the patches under selected positions and re-ranks their indices.
    #[test]
    fn test_filter() {
        let patch_indices = buffer![10u32, 11, 20].into_array();
        let patch_values = buffer![100, 110, 200].into_array();
        let patches = Patches::new(100, 0, patch_indices, patch_values);

        let mask = Mask::from_indices(100, vec![10, 20, 30]);
        let filtered = patches.filter(&mask).unwrap().unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Three sorted patches over a 20-element array, shared by the `rstest` cases below.
    #[fixture]
    fn patches() -> Patches {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        Patches::new(20, 0, indices, values)
    }

    /// A target above every patch value lands just past the last patched position.
    #[rstest]
    fn search_larger_than(patches: Patches) {
        let result = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(result, SearchResult::NotFound(16));
    }

    /// A target below every patch value lands on the first patched position.
    #[rstest]
    fn search_less_than(patches: Patches) {
        let result = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(result, SearchResult::NotFound(2));
    }

    /// An exact match is reported at the patched position of that value.
    #[rstest]
    fn search_found(patches: Patches) {
        let result = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(result, SearchResult::Found(9));
    }

    /// A right-sided search for a missing, too-large target also lands past the last patch.
    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let result = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(result, SearchResult::NotFound(16));
    }

    /// Searching sliced patches reports positions relative to the slice start.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        let result = sliced.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(result, SearchResult::NotFound(2));
    }

    /// Right-sided searches for present and absent targets.
    #[test]
    fn search_right() {
        let indices = buffer![0u8, 1, 4, 5].into_array();
        let values = buffer![-128i8, -98, 8, 50].into_array();
        let patches = Patches::new(6, 0, indices, values);

        let search = |target: i32| patches.search_sorted(target, SearchSortedSide::Right).unwrap();

        assert_eq!(search(-98), SearchResult::Found(2));
        assert_eq!(search(50), SearchResult::Found(6));
        assert_eq!(search(7), SearchResult::NotFound(2));
        assert_eq!(search(51), SearchResult::NotFound(6));
    }

    /// Left-sided searches for absent targets that fall between patch values.
    #[test]
    fn search_left() {
        let indices = buffer![0u64, 1, 17, 18, 19].into_array();
        let values = buffer![11i32, 22, 33, 44, 55].into_array();
        let patches = Patches::new(20, 0, indices, values);

        let search = |target: i32| patches.search_sorted(target, SearchSortedSide::Left).unwrap();

        assert_eq!(search(30), SearchResult::NotFound(2));
        assert_eq!(search(54), SearchResult::NotFound(19));
    }

    /// Taking with a nullable index array keeps only the entries that hit a patch.
    #[rstest]
    fn take_wit_nulls(patches: Patches) {
        let take_indices =
            PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                .into_array();

        let taken = patches.take(&take_indices).unwrap().unwrap();

        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// Slicing keeps only the patches inside the requested range.
    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);

        let sliced_values = sliced.values().to_primitive().unwrap();
        assert_eq!(sliced_values.as_slice::<u32>(), &[13531]);
    }

    /// Slicing already-sliced patches composes the offsets correctly.
    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let sliced_values = sliced.values().to_primitive().unwrap();
        assert_eq!(sliced_values.as_slice::<u32>(), &[13531]);

        // Position 35 of the first slice is position 50 of the original array.
        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let doubly_sliced_values = doubly_sliced.values().to_primitive().unwrap();
        assert_eq!(doubly_sliced_values.as_slice::<u32>(), &[13531]);
    }
}