1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::{NumCast, ToPrimitive};
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{
11 DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
12};
13use vortex_error::{
14 VortexError, VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err,
15};
16use vortex_mask::{AllOr, Mask};
17use vortex_scalar::{PValue, Scalar};
18use vortex_utils::aliases::hash_map::HashMap;
19
20use crate::arrays::PrimitiveArray;
21use crate::compute::{cast, filter, take};
22use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
23use crate::vtable::ValidityHelper;
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
/// Compact, serializable (serde and protobuf) summary of a [`Patches`]: the number of patches,
/// the index offset and the primitive type of the indices. Produced by
/// [`Patches::to_metadata`].
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches, i.e. the length of the indices (and values) arrays.
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset subtracted from each stored index to obtain a position in the patched array.
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// Primitive type of the indices array, stored as the raw protobuf enum value of [`PType`].
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
35
36impl PatchesMetadata {
37 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
38 Self {
39 len: len as u64,
40 offset: offset as u64,
41 indices_ptype: indices_ptype as i32,
42 }
43 }
44
45 #[inline]
46 pub fn len(&self) -> usize {
47 usize::try_from(self.len).vortex_expect("len is a valid usize")
48 }
49
50 #[inline]
51 pub fn is_empty(&self) -> bool {
52 self.len == 0
53 }
54
55 #[inline]
56 pub fn offset(&self) -> usize {
57 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
58 }
59
60 #[inline]
61 pub fn indices_dtype(&self) -> DType {
62 assert!(
63 self.indices_ptype().is_unsigned_int(),
64 "Patch indices must be unsigned integers"
65 );
66 DType::Primitive(self.indices_ptype(), NonNullable)
67 }
68}
69
/// A sparse set of replacement values for an array of length `array_len`.
///
/// `values[i]` replaces the element at position `indices[i] - offset`. The unsigned `indices`
/// are looked up with binary search, so they are expected to be sorted ascending.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array being patched.
    array_len: usize,
    /// Amount subtracted from every stored index; it grows when the patches are sliced.
    offset: usize,
    /// Positions of the patched elements, shifted by `offset`.
    indices: ArrayRef,
    /// Replacement values, one per index.
    values: ArrayRef,
}
78
79impl Patches {
80 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
81 assert_eq!(
82 indices.len(),
83 values.len(),
84 "Patch indices and values must have the same length"
85 );
86 assert!(
87 indices.dtype().is_unsigned_int(),
88 "Patch indices must be unsigned integers"
89 );
90 assert!(
91 indices.len() <= array_len,
92 "Patch indices must be shorter than the array length"
93 );
94 assert!(!indices.is_empty(), "Patch indices must not be empty");
95 let max = usize::try_from(
96 &indices
97 .scalar_at(indices.len() - 1)
98 .vortex_expect("indices are not empty"),
99 )
100 .vortex_expect("indices must be a number");
101 assert!(
102 max - offset < array_len,
103 "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
104 );
105 Self::new_unchecked(array_len, offset, indices, values)
106 }
107
108 pub fn new_unchecked(
118 array_len: usize,
119 offset: usize,
120 indices: ArrayRef,
121 values: ArrayRef,
122 ) -> Self {
123 Self {
124 array_len,
125 offset,
126 indices,
127 values,
128 }
129 }
130
131 pub fn array_len(&self) -> usize {
132 self.array_len
133 }
134
135 pub fn num_patches(&self) -> usize {
136 self.indices.len()
137 }
138
139 pub fn dtype(&self) -> &DType {
140 self.values.dtype()
141 }
142
143 pub fn indices(&self) -> &ArrayRef {
144 &self.indices
145 }
146
147 pub fn into_indices(self) -> ArrayRef {
148 self.indices
149 }
150
151 pub fn indices_mut(&mut self) -> &mut ArrayRef {
152 &mut self.indices
153 }
154
155 pub fn values(&self) -> &ArrayRef {
156 &self.values
157 }
158
159 pub fn into_values(self) -> ArrayRef {
160 self.values
161 }
162
163 pub fn values_mut(&mut self) -> &mut ArrayRef {
164 &mut self.values
165 }
166
167 pub fn offset(&self) -> usize {
168 self.offset
169 }
170
171 pub fn indices_ptype(&self) -> PType {
172 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
173 }
174
175 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
176 if self.indices.len() > len {
177 vortex_bail!(
178 "Patch indices {} are longer than the array length {}",
179 self.indices.len(),
180 len
181 );
182 }
183 if self.values.dtype() != dtype {
184 vortex_bail!(
185 "Patch values dtype {} does not match array dtype {}",
186 self.values.dtype(),
187 dtype
188 );
189 }
190 Ok(PatchesMetadata {
191 len: self.indices.len() as u64,
192 offset: self.offset as u64,
193 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
194 as i32,
195 })
196 }
197
198 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
199 Ok(Self::new_unchecked(
200 self.array_len,
201 self.offset,
202 self.indices,
203 cast(&self.values, values_dtype)?,
204 ))
205 }
206
207 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
209 if let Some(patch_idx) = self.search_index(index)?.to_found() {
210 self.values().scalar_at(patch_idx).map(Some)
211 } else {
212 Ok(None)
213 }
214 }
215
216 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
218 Ok(self.indices.as_primitive_typed().search_sorted(
219 &PValue::U64((index + self.offset) as u64),
220 SearchSortedSide::Left,
221 ))
222 }
223
224 pub fn search_sorted<T: Into<Scalar>>(
226 &self,
227 target: T,
228 side: SearchSortedSide,
229 ) -> VortexResult<SearchResult> {
230 let target = target.into();
231
232 let sr = if self.values().dtype().is_primitive() {
233 self.values()
234 .as_primitive_typed()
235 .search_sorted(&target.as_primitive().pvalue(), side)
236 } else {
237 self.values().search_sorted(&target, side)
238 };
239
240 let index_idx = sr.to_offsets_index(self.indices().len(), side);
241 let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
242 Ok(match sr {
243 SearchResult::Found(i) => SearchResult::Found(
245 if i == self.indices().len() || side == SearchSortedSide::Right {
246 index + 1
247 } else {
248 index
249 },
250 ),
251 SearchResult::NotFound(i) => {
253 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
254 }
255 })
256 }
257
258 pub fn min_index(&self) -> VortexResult<usize> {
260 Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
261 }
262
263 pub fn max_index(&self) -> VortexResult<usize> {
265 Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
266 }
267
268 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
270 match mask.indices() {
271 AllOr::All => Ok(Some(self.clone())),
272 AllOr::None => Ok(None),
273 AllOr::Some(mask_indices) => {
274 let flat_indices = self.indices().to_primitive()?;
275 match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
276 filter_patches_with_mask(
277 flat_indices.as_slice::<I>(),
278 self.offset(),
279 self.values(),
280 mask_indices,
281 )
282 })
283 }
284 }
285 }
286
287 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
289 let patch_start = self.search_index(start)?.to_index();
290 let patch_stop = self.search_index(stop)?.to_index();
291
292 if patch_start == patch_stop {
293 return Ok(None);
294 }
295
296 let values = self.values().slice(patch_start, patch_stop)?;
298 let indices = self.indices().slice(patch_start, patch_stop)?;
299
300 Ok(Some(Self::new(
301 stop - start,
302 start + self.offset(),
303 indices,
304 values,
305 )))
306 }
307
308 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
310
311 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
312 (self.num_patches() as f64 / take_indices.len() as f64)
313 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
314 }
315
316 pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
320 if take_indices.is_empty() {
321 return Ok(None);
322 }
323
324 let take_indices = take_indices.to_primitive()?;
325 if self.is_map_faster_than_search(&take_indices) {
326 self.take_map(take_indices, true)
327 } else {
328 self.take_search(take_indices, true)
329 }
330 }
331
332 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
336 if take_indices.is_empty() {
337 return Ok(None);
338 }
339
340 let take_indices = take_indices.to_primitive()?;
341 if self.is_map_faster_than_search(&take_indices) {
342 self.take_map(take_indices, false)
343 } else {
344 self.take_search(take_indices, false)
345 }
346 }
347
348 pub fn take_search(
349 &self,
350 take_indices: PrimitiveArray,
351 include_nulls: bool,
352 ) -> VortexResult<Option<Self>> {
353 let indices = self.indices.to_primitive()?;
354 let new_length = take_indices.len();
355
356 let Some((new_indices, values_indices)) =
357 match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
358 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
359 take_search::<_, TakeIndices>(
360 indices.as_slice::<Indices>(),
361 take_indices,
362 self.offset(),
363 include_nulls,
364 )?
365 })
366 })
367 else {
368 return Ok(None);
369 };
370
371 Ok(Some(Self::new(
372 new_length,
373 0,
374 new_indices,
375 take(self.values(), &values_indices)?,
376 )))
377 }
378
379 pub fn take_map(
380 &self,
381 take_indices: PrimitiveArray,
382 include_nulls: bool,
383 ) -> VortexResult<Option<Self>> {
384 let indices = self.indices.to_primitive()?;
385 let new_length = take_indices.len();
386
387 let Some((new_sparse_indices, value_indices)) =
388 match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
389 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
390 take_map::<_, TakeIndices>(
391 indices.as_slice::<Indices>(),
392 take_indices,
393 self.offset(),
394 self.min_index()?,
395 self.max_index()?,
396 include_nulls,
397 )?
398 })
399 })
400 else {
401 return Ok(None);
402 };
403
404 Ok(Some(Patches::new(
405 new_length,
406 0,
407 new_sparse_indices,
408 take(self.values(), &value_indices)?,
409 )))
410 }
411
412 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
413 where
414 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
415 {
416 let values = f(self.values)?;
417 if self.indices.len() != values.len() {
418 vortex_bail!(
419 "map_values must preserve length: expected {} received {}",
420 self.indices.len(),
421 values.len()
422 )
423 }
424 Ok(Self::new(self.array_len, self.offset, self.indices, values))
425 }
426}
427
428fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
429 indices: &[I],
430 take_indices: PrimitiveArray,
431 indices_offset: usize,
432 include_nulls: bool,
433) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
434where
435 usize: TryFrom<T>,
436 VortexError: From<<usize as TryFrom<T>>::Error>,
437{
438 let take_indices_validity = take_indices.validity();
439 let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
440
441 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
442 .as_slice::<T>()
443 .iter()
444 .enumerate()
445 .filter_map(|(i, &v)| {
446 I::from(v)
447 .and_then(|v| {
448 if include_nulls && take_indices_validity.is_null(i).vortex_unwrap() {
450 Some(0)
451 } else {
452 indices
453 .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
454 .to_found()
455 .map(|patch_idx| patch_idx as u64)
456 }
457 })
458 .map(|patch_idx| (patch_idx, i as u64))
459 })
460 .unzip();
461
462 if new_indices.is_empty() {
463 return Ok(None);
464 }
465
466 let new_indices = new_indices.into_array();
467 let values_validity = take_indices_validity.take(&new_indices)?;
468 Ok(Some((
469 new_indices,
470 PrimitiveArray::new(values_indices, values_validity).into_array(),
471 )))
472}
473
474fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
475 indices: &[I],
476 take_indices: PrimitiveArray,
477 indices_offset: usize,
478 min_index: usize,
479 max_index: usize,
480 include_nulls: bool,
481) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
482where
483 usize: TryFrom<T>,
484 VortexError: From<<I as TryFrom<usize>>::Error>,
485{
486 let take_indices_validity = take_indices.validity();
487 let take_indices = take_indices.as_slice::<T>();
488 let offset_i = I::try_from(indices_offset)?;
489
490 let sparse_index_to_value_index: HashMap<I, usize> = indices
491 .iter()
492 .copied()
493 .map(|idx| idx - offset_i)
494 .enumerate()
495 .map(|(value_index, sparse_index)| (sparse_index, value_index))
496 .collect();
497
498 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
499 .iter()
500 .copied()
501 .map(usize::try_from)
502 .process_results(|iter| {
503 iter.enumerate()
504 .filter_map(|(idx_in_take, ti)| {
505 if include_nulls && take_indices_validity.is_null(idx_in_take).vortex_unwrap() {
507 Some((idx_in_take as u64, 0))
508 } else if ti < min_index || ti > max_index {
509 None
510 } else {
511 sparse_index_to_value_index
512 .get(
513 &I::try_from(ti)
514 .vortex_expect("take index is between min and max index"),
515 )
516 .map(|value_index| (idx_in_take as u64, *value_index as u64))
517 }
518 })
519 .unzip()
520 })
521 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
522
523 if new_sparse_indices.is_empty() {
524 return Ok(None);
525 }
526
527 let new_sparse_indices = new_sparse_indices.into_array();
528 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
529 Ok(Some((
530 new_sparse_indices,
531 PrimitiveArray::new(value_indices, values_validity).into_array(),
532 )))
533}
534
535fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
541 patch_indices: &[T],
542 offset: usize,
543 patch_values: &dyn Array,
544 mask_indices: &[usize],
545) -> VortexResult<Option<Patches>> {
546 let true_count = mask_indices.len();
547 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
548 let mut new_mask_indices = Vec::with_capacity(true_count);
549
550 const STRIDE: usize = 4;
554
555 let mut mask_idx = 0usize;
556 let mut true_idx = 0usize;
557
558 while mask_idx < patch_indices.len() && true_idx < true_count {
559 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
566 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
568 let left_max = patch_indices[mask_idx + STRIDE]
569 .to_usize()
570 .vortex_expect("left_max")
571 - offset;
572 let right_min = mask_indices[true_idx];
573 let right_max = mask_indices[true_idx + STRIDE];
574
575 if left_min > right_max {
576 true_idx += STRIDE;
578 continue;
579 } else if right_min > left_max {
580 mask_idx += STRIDE;
581 continue;
582 } else {
583 }
585 }
586
587 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
590 let right = mask_indices[true_idx];
591
592 match left.cmp(&right) {
593 Ordering::Less => {
594 mask_idx += 1;
595 }
596 Ordering::Greater => {
597 true_idx += 1;
598 }
599 Ordering::Equal => {
600 new_mask_indices.push(mask_idx);
602 new_patch_indices.push(true_idx as u64);
603
604 mask_idx += 1;
605 true_idx += 1;
606 }
607 }
608 }
609
610 if new_mask_indices.is_empty() {
611 return Ok(None);
612 }
613
614 let new_patch_indices = new_patch_indices.into_array();
615 let new_patch_values = filter(
616 patch_values,
617 &Mask::from_indices(patch_values.len(), new_mask_indices),
618 )?;
619
620 Ok(Some(Patches::new(
621 true_count,
622 0,
623 new_patch_indices,
624 new_patch_values,
625 )))
626}
627
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::arrays::PrimitiveArray;
    use crate::patches::Patches;
    use crate::search_sorted::{SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let patches = Patches::new(
            6,
            0,
            buffer![0u8, 1, 4, 5].into_array(),
            buffer![-128i8, -98, 8, 50].into_array(),
        );

        assert_eq!(
            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(2)
        );
        assert_eq!(
            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(6),
        );
        assert_eq!(
            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(2),
        );
        assert_eq!(
            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(6)
        );
    }

    #[test]
    fn search_left() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        assert_eq!(
            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(19)
        );
    }

    /// `take` with a null take index: the null position does not hit a patch and is dropped.
    #[rstest]
    fn take_with_null_indices(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// `take_with_nulls` keeps every null take position as a null patch.
    #[rstest]
    fn take_with_nulls_keeps_null_positions(patches: Patches) {
        let taken = patches
            .take_with_nulls(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        assert_eq!(taken.array_len(), 2);
        let indices = taken.indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), [0, 1]);
        let values = taken.values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>()[0], 44);
        assert_eq!(
            values.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false])
        );
    }

    /// Five patches and a single take index give a ratio of 5, which selects the
    /// binary-search take path.
    #[test]
    fn take_search_path() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        let taken = patches
            .take(&buffer![17u64].into_array())
            .unwrap()
            .unwrap();
        assert_eq!(taken.array_len(), 1);
        let indices = taken.indices().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), [0]);
        let values = taken.values().to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), [33]);
    }

    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
    }

    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();

        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
    }
}