1use std::cmp::Ordering;
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use itertools::Itertools as _;
9use num_traits::{NumCast, ToPrimitive};
10use serde::{Deserialize, Serialize};
11use vortex_buffer::BufferMut;
12use vortex_dtype::Nullability::NonNullable;
13use vortex_dtype::{
14 DType, NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
15};
16use vortex_error::{
17 VortexError, VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err,
18};
19use vortex_mask::{AllOr, Mask};
20use vortex_scalar::{PValue, Scalar};
21use vortex_utils::aliases::hash_map::HashMap;
22
23use crate::arrays::PrimitiveArray;
24use crate::compute::{cast, filter, take};
25use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
26use crate::vtable::ValidityHelper;
27use crate::{Array, ArrayRef, IntoArray, ToCanonical};
28
/// Serializable summary of a [`Patches`] instance: how many patches there are, the offset
/// recorded with their indices, and the primitive type of the indices array.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches.
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset recorded alongside the patch indices.
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// Primitive type of the patch indices, stored as the `i32` value of the `PType` enumeration.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
38
39impl PatchesMetadata {
40 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
41 Self {
42 len: len as u64,
43 offset: offset as u64,
44 indices_ptype: indices_ptype as i32,
45 }
46 }
47
48 #[inline]
49 pub fn len(&self) -> usize {
50 usize::try_from(self.len).vortex_expect("len is a valid usize")
51 }
52
53 #[inline]
54 pub fn is_empty(&self) -> bool {
55 self.len == 0
56 }
57
58 #[inline]
59 pub fn offset(&self) -> usize {
60 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
61 }
62
63 #[inline]
64 pub fn indices_dtype(&self) -> DType {
65 assert!(
66 self.indices_ptype().is_unsigned_int(),
67 "Patch indices must be unsigned integers"
68 );
69 DType::Primitive(self.indices_ptype(), NonNullable)
70 }
71}
72
/// A sparse set of replacement values for an array: `values[i]` is the patch for the position
/// stored in `indices[i]`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount subtracted from each stored index to obtain its position within the array.
    offset: usize,
    /// Patched positions, shifted by `offset`; `Patches::new` requires an unsigned integer dtype.
    indices: ArrayRef,
    /// Patch values, one per entry of `indices`.
    values: ArrayRef,
}
81
82impl Patches {
83 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
84 assert_eq!(
85 indices.len(),
86 values.len(),
87 "Patch indices and values must have the same length"
88 );
89 assert!(
90 indices.dtype().is_unsigned_int(),
91 "Patch indices must be unsigned integers"
92 );
93 assert!(
94 indices.len() <= array_len,
95 "Patch indices must be shorter than the array length"
96 );
97 assert!(!indices.is_empty(), "Patch indices must not be empty");
98 let max = usize::try_from(
99 &indices
100 .scalar_at(indices.len() - 1)
101 .vortex_expect("indices are not empty"),
102 )
103 .vortex_expect("indices must be a number");
104 assert!(
105 max - offset < array_len,
106 "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
107 );
108 Self::new_unchecked(array_len, offset, indices, values)
109 }
110
111 pub fn new_unchecked(
121 array_len: usize,
122 offset: usize,
123 indices: ArrayRef,
124 values: ArrayRef,
125 ) -> Self {
126 Self {
127 array_len,
128 offset,
129 indices,
130 values,
131 }
132 }
133
134 pub fn array_len(&self) -> usize {
135 self.array_len
136 }
137
138 pub fn num_patches(&self) -> usize {
139 self.indices.len()
140 }
141
142 pub fn dtype(&self) -> &DType {
143 self.values.dtype()
144 }
145
146 pub fn indices(&self) -> &ArrayRef {
147 &self.indices
148 }
149
150 pub fn into_indices(self) -> ArrayRef {
151 self.indices
152 }
153
154 pub fn indices_mut(&mut self) -> &mut ArrayRef {
155 &mut self.indices
156 }
157
158 pub fn values(&self) -> &ArrayRef {
159 &self.values
160 }
161
162 pub fn into_values(self) -> ArrayRef {
163 self.values
164 }
165
166 pub fn values_mut(&mut self) -> &mut ArrayRef {
167 &mut self.values
168 }
169
170 pub fn offset(&self) -> usize {
171 self.offset
172 }
173
174 pub fn indices_ptype(&self) -> PType {
175 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
176 }
177
178 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
179 if self.indices.len() > len {
180 vortex_bail!(
181 "Patch indices {} are longer than the array length {}",
182 self.indices.len(),
183 len
184 );
185 }
186 if self.values.dtype() != dtype {
187 vortex_bail!(
188 "Patch values dtype {} does not match array dtype {}",
189 self.values.dtype(),
190 dtype
191 );
192 }
193 Ok(PatchesMetadata {
194 len: self.indices.len() as u64,
195 offset: self.offset as u64,
196 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
197 as i32,
198 })
199 }
200
201 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
202 Ok(Self::new_unchecked(
203 self.array_len,
204 self.offset,
205 self.indices,
206 cast(&self.values, values_dtype)?,
207 ))
208 }
209
210 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
212 if let Some(patch_idx) = self.search_index(index)?.to_found() {
213 self.values().scalar_at(patch_idx).map(Some)
214 } else {
215 Ok(None)
216 }
217 }
218
219 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
221 Ok(self.indices.as_primitive_typed().search_sorted(
222 &PValue::U64((index + self.offset) as u64),
223 SearchSortedSide::Left,
224 ))
225 }
226
227 pub fn search_sorted<T: Into<Scalar>>(
229 &self,
230 target: T,
231 side: SearchSortedSide,
232 ) -> VortexResult<SearchResult> {
233 let target = target.into();
234
235 let sr = if self.values().dtype().is_primitive() {
236 self.values()
237 .as_primitive_typed()
238 .search_sorted(&target.as_primitive().pvalue(), side)
239 } else {
240 self.values().search_sorted(&target, side)
241 };
242
243 let index_idx = sr.to_offsets_index(self.indices().len(), side);
244 let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
245 Ok(match sr {
246 SearchResult::Found(i) => SearchResult::Found(
248 if i == self.indices().len() || side == SearchSortedSide::Right {
249 index + 1
250 } else {
251 index
252 },
253 ),
254 SearchResult::NotFound(i) => {
256 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
257 }
258 })
259 }
260
261 pub fn min_index(&self) -> VortexResult<usize> {
263 Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
264 }
265
266 pub fn max_index(&self) -> VortexResult<usize> {
268 Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
269 }
270
271 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
273 match mask.indices() {
274 AllOr::All => Ok(Some(self.clone())),
275 AllOr::None => Ok(None),
276 AllOr::Some(mask_indices) => {
277 let flat_indices = self.indices().to_primitive()?;
278 match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| {
279 filter_patches_with_mask(
280 flat_indices.as_slice::<I>(),
281 self.offset(),
282 self.values(),
283 mask_indices,
284 )
285 })
286 }
287 }
288 }
289
290 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
292 let patch_start = self.search_index(start)?.to_index();
293 let patch_stop = self.search_index(stop)?.to_index();
294
295 if patch_start == patch_stop {
296 return Ok(None);
297 }
298
299 let values = self.values().slice(patch_start, patch_stop)?;
301 let indices = self.indices().slice(patch_start, patch_stop)?;
302
303 Ok(Some(Self::new(
304 stop - start,
305 start + self.offset(),
306 indices,
307 values,
308 )))
309 }
310
311 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
313
314 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
315 (self.num_patches() as f64 / take_indices.len() as f64)
316 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
317 }
318
319 pub fn take_with_nulls(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
323 if take_indices.is_empty() {
324 return Ok(None);
325 }
326
327 let take_indices = take_indices.to_primitive()?;
328 if self.is_map_faster_than_search(&take_indices) {
329 self.take_map(take_indices, true)
330 } else {
331 self.take_search(take_indices, true)
332 }
333 }
334
335 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
339 if take_indices.is_empty() {
340 return Ok(None);
341 }
342
343 let take_indices = take_indices.to_primitive()?;
344 if self.is_map_faster_than_search(&take_indices) {
345 self.take_map(take_indices, false)
346 } else {
347 self.take_search(take_indices, false)
348 }
349 }
350
351 pub fn take_search(
352 &self,
353 take_indices: PrimitiveArray,
354 include_nulls: bool,
355 ) -> VortexResult<Option<Self>> {
356 let indices = self.indices.to_primitive()?;
357 let new_length = take_indices.len();
358
359 let Some((new_indices, values_indices)) =
360 match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
361 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
362 take_search::<_, TakeIndices>(
363 indices.as_slice::<Indices>(),
364 take_indices,
365 self.offset(),
366 include_nulls,
367 )?
368 })
369 })
370 else {
371 return Ok(None);
372 };
373
374 Ok(Some(Self::new(
375 new_length,
376 0,
377 new_indices,
378 take(self.values(), &values_indices)?,
379 )))
380 }
381
382 pub fn take_map(
383 &self,
384 take_indices: PrimitiveArray,
385 include_nulls: bool,
386 ) -> VortexResult<Option<Self>> {
387 let indices = self.indices.to_primitive()?;
388 let new_length = take_indices.len();
389
390 let Some((new_sparse_indices, value_indices)) =
391 match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| {
392 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
393 take_map::<_, TakeIndices>(
394 indices.as_slice::<Indices>(),
395 take_indices,
396 self.offset(),
397 self.min_index()?,
398 self.max_index()?,
399 include_nulls,
400 )?
401 })
402 })
403 else {
404 return Ok(None);
405 };
406
407 Ok(Some(Patches::new(
408 new_length,
409 0,
410 new_sparse_indices,
411 take(self.values(), &value_indices)?,
412 )))
413 }
414
415 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
416 where
417 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
418 {
419 let values = f(self.values)?;
420 if self.indices.len() != values.len() {
421 vortex_bail!(
422 "map_values must preserve length: expected {} received {}",
423 self.indices.len(),
424 values.len()
425 )
426 }
427 Ok(Self::new(self.array_len, self.offset, self.indices, values))
428 }
429}
430
431fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
432 indices: &[I],
433 take_indices: PrimitiveArray,
434 indices_offset: usize,
435 include_nulls: bool,
436) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
437where
438 usize: TryFrom<T>,
439 VortexError: From<<usize as TryFrom<T>>::Error>,
440{
441 let take_indices_validity = take_indices.validity();
442 let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
443
444 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
445 .as_slice::<T>()
446 .iter()
447 .enumerate()
448 .filter_map(|(i, &v)| {
449 I::from(v)
450 .and_then(|v| {
451 if include_nulls && take_indices_validity.is_null(i).vortex_unwrap() {
453 Some(0)
454 } else {
455 indices
456 .search_sorted(&(v + indices_offset), SearchSortedSide::Left)
457 .to_found()
458 .map(|patch_idx| patch_idx as u64)
459 }
460 })
461 .map(|patch_idx| (patch_idx, i as u64))
462 })
463 .unzip();
464
465 if new_indices.is_empty() {
466 return Ok(None);
467 }
468
469 let new_indices = new_indices.into_array();
470 let values_validity = take_indices_validity.take(&new_indices)?;
471 Ok(Some((
472 new_indices,
473 PrimitiveArray::new(values_indices, values_validity).into_array(),
474 )))
475}
476
477fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
478 indices: &[I],
479 take_indices: PrimitiveArray,
480 indices_offset: usize,
481 min_index: usize,
482 max_index: usize,
483 include_nulls: bool,
484) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
485where
486 usize: TryFrom<T>,
487 VortexError: From<<I as TryFrom<usize>>::Error>,
488{
489 let take_indices_validity = take_indices.validity();
490 let take_indices = take_indices.as_slice::<T>();
491 let offset_i = I::try_from(indices_offset)?;
492
493 let sparse_index_to_value_index: HashMap<I, usize> = indices
494 .iter()
495 .copied()
496 .map(|idx| idx - offset_i)
497 .enumerate()
498 .map(|(value_index, sparse_index)| (sparse_index, value_index))
499 .collect();
500
501 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
502 .iter()
503 .copied()
504 .map(usize::try_from)
505 .process_results(|iter| {
506 iter.enumerate()
507 .filter_map(|(idx_in_take, ti)| {
508 if include_nulls && take_indices_validity.is_null(idx_in_take).vortex_unwrap() {
510 Some((idx_in_take as u64, 0))
511 } else if ti < min_index || ti > max_index {
512 None
513 } else {
514 sparse_index_to_value_index
515 .get(
516 &I::try_from(ti)
517 .vortex_expect("take index is between min and max index"),
518 )
519 .map(|value_index| (idx_in_take as u64, *value_index as u64))
520 }
521 })
522 .unzip()
523 })
524 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
525
526 if new_sparse_indices.is_empty() {
527 return Ok(None);
528 }
529
530 let new_sparse_indices = new_sparse_indices.into_array();
531 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
532 Ok(Some((
533 new_sparse_indices,
534 PrimitiveArray::new(value_indices, values_validity).into_array(),
535 )))
536}
537
538fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
544 patch_indices: &[T],
545 offset: usize,
546 patch_values: &dyn Array,
547 mask_indices: &[usize],
548) -> VortexResult<Option<Patches>> {
549 let true_count = mask_indices.len();
550 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
551 let mut new_mask_indices = Vec::with_capacity(true_count);
552
553 const STRIDE: usize = 4;
557
558 let mut mask_idx = 0usize;
559 let mut true_idx = 0usize;
560
561 while mask_idx < patch_indices.len() && true_idx < true_count {
562 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
569 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
571 let left_max = patch_indices[mask_idx + STRIDE]
572 .to_usize()
573 .vortex_expect("left_max")
574 - offset;
575 let right_min = mask_indices[true_idx];
576 let right_max = mask_indices[true_idx + STRIDE];
577
578 if left_min > right_max {
579 true_idx += STRIDE;
581 continue;
582 } else if right_min > left_max {
583 mask_idx += STRIDE;
584 continue;
585 } else {
586 }
588 }
589
590 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
593 let right = mask_indices[true_idx];
594
595 match left.cmp(&right) {
596 Ordering::Less => {
597 mask_idx += 1;
598 }
599 Ordering::Greater => {
600 true_idx += 1;
601 }
602 Ordering::Equal => {
603 new_mask_indices.push(mask_idx);
605 new_patch_indices.push(true_idx as u64);
606
607 mask_idx += 1;
608 true_idx += 1;
609 }
610 }
611 }
612
613 if new_mask_indices.is_empty() {
614 return Ok(None);
615 }
616
617 let new_patch_indices = new_patch_indices.into_array();
618 let new_patch_values = filter(
619 patch_values,
620 &Mask::from_indices(patch_values.len(), new_mask_indices),
621 )?;
622
623 Ok(Some(Patches::new(
624 true_count,
625 0,
626 new_patch_indices,
627 new_patch_values,
628 )))
629}
630
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::arrays::PrimitiveArray;
    use crate::patches::Patches;
    use crate::search_sorted::{SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps the patches at selected positions and renumbers them by their rank
    /// among the selected positions.
    #[test]
    fn test_filter() {
        let indices = buffer![10u32, 11, 20].into_array();
        let values = buffer![100, 110, 200].into_array();
        let patches = Patches::new(100, 0, indices, values);

        let mask = Mask::from_indices(100, vec![10, 20, 30]);
        let filtered = patches.filter(&mask).unwrap().unwrap();

        let filtered_indices = filtered.indices().to_primitive().unwrap();
        let filtered_values = filtered.values().to_primitive().unwrap();
        assert_eq!(filtered_indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(filtered_values.as_slice::<i32>(), &[100, 200]);
    }

    /// Three patches (at positions 2, 9 and 15) over an array of length 20.
    #[fixture]
    fn patches() -> Patches {
        let indices = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        Patches::new(20, 0, indices, values)
    }

    /// A target above every patch value lands just after the last patch.
    #[rstest]
    fn search_larger_than(patches: Patches) {
        let found = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(found, SearchResult::NotFound(16));
    }

    /// A target below every patch value lands on the first patch.
    #[rstest]
    fn search_less_than(patches: Patches) {
        let found = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(found, SearchResult::NotFound(2));
    }

    /// An exact match reports the position of the matching patch.
    #[rstest]
    fn search_found(patches: Patches) {
        let found = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(found, SearchResult::Found(9));
    }

    /// A right-sided miss above every patch value lands just after the last patch.
    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let found = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(found, SearchResult::NotFound(16));
    }

    /// Searching sliced patches reports positions relative to the slice.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        let found = sliced.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(found, SearchResult::NotFound(2));
    }

    /// Right-sided searches, both hits and misses.
    #[test]
    fn search_right() {
        let indices = buffer![0u8, 1, 4, 5].into_array();
        let values = buffer![-128i8, -98, 8, 50].into_array();
        let patches = Patches::new(6, 0, indices, values);

        let hit_inside = patches.search_sorted(-98, SearchSortedSide::Right).unwrap();
        assert_eq!(hit_inside, SearchResult::Found(2));

        let hit_last = patches.search_sorted(50, SearchSortedSide::Right).unwrap();
        assert_eq!(hit_last, SearchResult::Found(6),);

        let miss_inside = patches.search_sorted(7, SearchSortedSide::Right).unwrap();
        assert_eq!(miss_inside, SearchResult::NotFound(2),);

        let miss_after = patches.search_sorted(51, SearchSortedSide::Right).unwrap();
        assert_eq!(miss_after, SearchResult::NotFound(6));
    }

    /// Left-sided misses between patch values.
    #[test]
    fn search_left() {
        let indices = buffer![0u64, 1, 17, 18, 19].into_array();
        let values = buffer![11i32, 22, 33, 44, 55].into_array();
        let patches = Patches::new(20, 0, indices, values);

        let miss_low = patches.search_sorted(30, SearchSortedSide::Left).unwrap();
        assert_eq!(miss_low, SearchResult::NotFound(2));

        let miss_high = patches.search_sorted(54, SearchSortedSide::Left).unwrap();
        assert_eq!(miss_high, SearchResult::NotFound(19));
    }

    /// `take` drops null take indices: only the valid index 9 yields a patch.
    #[rstest]
    fn take_wit_nulls(patches: Patches) {
        let take_indices =
            PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                .into_array();
        let taken = patches.take(&take_indices).unwrap().unwrap();

        let taken_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(taken_values.as_slice::<i32>(), [44]);
        assert_eq!(
            taken_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// Slicing keeps only the patches inside the range and shrinks the array length.
    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);

        let sliced_values = sliced.values().to_primitive().unwrap();
        assert_eq!(sliced_values.as_slice::<u32>(), &[13531]);
    }

    /// Slicing already-sliced patches interprets the range relative to the first slice.
    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);

        let sliced_values = sliced.values().to_primitive().unwrap();
        assert_eq!(sliced_values.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let doubly_sliced_values = doubly_sliced.values().to_primitive().unwrap();
        assert_eq!(doubly_sliced_values.as_slice::<u32>(), &[13531]);
    }
}