1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::{NumCast, ToPrimitive};
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::{PValue, Scalar};
14use vortex_utils::aliases::hash_map::HashMap;
15
16use crate::arrays::PrimitiveArray;
17use crate::compute::{cast, filter, take};
18use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
19use crate::vtable::ValidityHelper;
20use crate::{Array, ArrayRef, IntoArray, ToCanonical};
21
/// Serializable summary of a set of patches: how many there are, the offset applied to the
/// stored indices, and the primitive type used to store those indices.
///
/// The struct derives both serde and prost encodings; the `tag` numbers below identify the
/// fields in the prost encoding.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches, stored as `u64` (read back as `usize` via `len()`).
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Index offset, stored as `u64` (read back as `usize` via `offset()`).
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// `PType` of the patch indices, stored as the enumeration's `i32` value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
31
32impl PatchesMetadata {
33 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
34 Self {
35 len: len as u64,
36 offset: offset as u64,
37 indices_ptype: indices_ptype as i32,
38 }
39 }
40
41 #[inline]
42 pub fn len(&self) -> usize {
43 usize::try_from(self.len).vortex_expect("len is a valid usize")
44 }
45
46 #[inline]
47 pub fn is_empty(&self) -> bool {
48 self.len == 0
49 }
50
51 #[inline]
52 pub fn offset(&self) -> usize {
53 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
54 }
55
56 #[inline]
57 pub fn indices_dtype(&self) -> DType {
58 assert!(
59 self.indices_ptype().is_unsigned_int(),
60 "Patch indices must be unsigned integers"
61 );
62 DType::Primitive(self.indices_ptype(), NonNullable)
63 }
64}
65
/// A set of positional overrides ("patches") for an array of `array_len` elements.
///
/// `values[i]` is the patch for the position stored in `indices[i]`. Stored indices are shifted
/// by `offset`: lookups add `offset` to a logical position before searching `indices`, and
/// subtract it again when reporting positions (see `search_index` and `min_index`).
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount by which the stored indices are shifted relative to logical positions.
    offset: usize,
    /// Positions of the patches; `Patches::new` requires an unsigned integer dtype.
    indices: ArrayRef,
    /// Patch values, one per entry of `indices`.
    values: ArrayRef,
}
74
75impl Patches {
76 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
77 assert_eq!(
78 indices.len(),
79 values.len(),
80 "Patch indices and values must have the same length"
81 );
82 assert!(
83 indices.dtype().is_unsigned_int(),
84 "Patch indices must be unsigned integers"
85 );
86 assert!(
87 indices.len() <= array_len,
88 "Patch indices must be shorter than the array length"
89 );
90 assert!(!indices.is_empty(), "Patch indices must not be empty");
91 let max = usize::try_from(
92 &indices
93 .scalar_at(indices.len() - 1)
94 .vortex_expect("indices are not empty"),
95 )
96 .vortex_expect("indices must be a number");
97 assert!(
98 max - offset < array_len,
99 "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
100 );
101 Self::new_unchecked(array_len, offset, indices, values)
102 }
103
104 pub fn new_unchecked(
114 array_len: usize,
115 offset: usize,
116 indices: ArrayRef,
117 values: ArrayRef,
118 ) -> Self {
119 Self {
120 array_len,
121 offset,
122 indices,
123 values,
124 }
125 }
126
127 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
129 (self.array_len, self.offset, self.indices, self.values)
130 }
131
132 pub fn array_len(&self) -> usize {
133 self.array_len
134 }
135
136 pub fn num_patches(&self) -> usize {
137 self.indices.len()
138 }
139
140 pub fn dtype(&self) -> &DType {
141 self.values.dtype()
142 }
143
144 pub fn indices(&self) -> &ArrayRef {
145 &self.indices
146 }
147
148 pub fn into_indices(self) -> ArrayRef {
149 self.indices
150 }
151
152 pub fn indices_mut(&mut self) -> &mut ArrayRef {
153 &mut self.indices
154 }
155
156 pub fn values(&self) -> &ArrayRef {
157 &self.values
158 }
159
160 pub fn into_values(self) -> ArrayRef {
161 self.values
162 }
163
164 pub fn values_mut(&mut self) -> &mut ArrayRef {
165 &mut self.values
166 }
167
168 pub fn offset(&self) -> usize {
169 self.offset
170 }
171
172 pub fn indices_ptype(&self) -> PType {
173 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
174 }
175
176 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
177 if self.indices.len() > len {
178 vortex_bail!(
179 "Patch indices {} are longer than the array length {}",
180 self.indices.len(),
181 len
182 );
183 }
184 if self.values.dtype() != dtype {
185 vortex_bail!(
186 "Patch values dtype {} does not match array dtype {}",
187 self.values.dtype(),
188 dtype
189 );
190 }
191 Ok(PatchesMetadata {
192 len: self.indices.len() as u64,
193 offset: self.offset as u64,
194 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
195 as i32,
196 })
197 }
198
199 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
200 Ok(Self::new_unchecked(
201 self.array_len,
202 self.offset,
203 self.indices,
204 cast(&self.values, values_dtype)?,
205 ))
206 }
207
208 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
210 if let Some(patch_idx) = self.search_index(index)?.to_found() {
211 self.values().scalar_at(patch_idx).map(Some)
212 } else {
213 Ok(None)
214 }
215 }
216
217 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
219 Ok(self.indices.as_primitive_typed().search_sorted(
220 &PValue::U64((index + self.offset) as u64),
221 SearchSortedSide::Left,
222 ))
223 }
224
225 pub fn search_sorted<T: Into<Scalar>>(
227 &self,
228 target: T,
229 side: SearchSortedSide,
230 ) -> VortexResult<SearchResult> {
231 let target = target.into();
232
233 let sr = if self.values().dtype().is_primitive() {
234 self.values()
235 .as_primitive_typed()
236 .search_sorted(&target.as_primitive().pvalue(), side)
237 } else {
238 self.values().search_sorted(&target, side)
239 };
240
241 let index_idx = sr.to_offsets_index(self.indices().len(), side);
242 let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
243 Ok(match sr {
244 SearchResult::Found(i) => SearchResult::Found(
246 if i == self.indices().len() || side == SearchSortedSide::Right {
247 index + 1
248 } else {
249 index
250 },
251 ),
252 SearchResult::NotFound(i) => {
254 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
255 }
256 })
257 }
258
259 pub fn min_index(&self) -> VortexResult<usize> {
261 Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
262 }
263
264 pub fn max_index(&self) -> VortexResult<usize> {
266 Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
267 }
268
269 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
271 match mask.indices() {
272 AllOr::All => Ok(Some(self.clone())),
273 AllOr::None => Ok(None),
274 AllOr::Some(mask_indices) => {
275 let flat_indices = self.indices().to_primitive()?;
276 match_each_integer_ptype!(flat_indices.ptype(), |I| {
277 filter_patches_with_mask(
278 flat_indices.as_slice::<I>(),
279 self.offset(),
280 self.values(),
281 mask_indices,
282 )
283 })
284 }
285 }
286 }
287
288 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
290 let patch_start = self.search_index(start)?.to_index();
291 let patch_stop = self.search_index(stop)?.to_index();
292
293 if patch_start == patch_stop {
294 return Ok(None);
295 }
296
297 let values = self.values().slice(patch_start, patch_stop)?;
299 let indices = self.indices().slice(patch_start, patch_stop)?;
300
301 Ok(Some(Self::new(
302 stop - start,
303 start + self.offset(),
304 indices,
305 values,
306 )))
307 }
308
309 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
311
312 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
313 (self.num_patches() as f64 / take_indices.len() as f64)
314 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
315 }
316
317 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
319 if take_indices.is_empty() {
320 return Ok(None);
321 }
322 let take_indices = take_indices.to_primitive()?;
323 if self.is_map_faster_than_search(&take_indices) {
324 self.take_map(take_indices)
325 } else {
326 self.take_search(take_indices)
327 }
328 }
329
330 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
331 let indices = self.indices.to_primitive()?;
332 let new_length = take_indices.len();
333
334 let Some((new_indices, values_indices)) =
335 match_each_integer_ptype!(indices.ptype(), |Indices| {
336 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
337 take_search::<_, TakeIndices>(
338 indices.as_slice::<Indices>(),
339 take_indices,
340 self.offset(),
341 )?
342 })
343 })
344 else {
345 return Ok(None);
346 };
347
348 Ok(Some(Self::new(
349 new_length,
350 0,
351 new_indices,
352 take(self.values(), &values_indices)?,
353 )))
354 }
355
356 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
357 let indices = self.indices.to_primitive()?;
358 let new_length = take_indices.len();
359
360 let Some((new_sparse_indices, value_indices)) =
361 match_each_integer_ptype!(self.indices_ptype(), |Indices| {
362 match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| {
363 take_map::<_, TakeIndices>(
364 indices.as_slice::<Indices>(),
365 take_indices,
366 self.offset(),
367 self.min_index()?,
368 self.max_index()?,
369 )?
370 })
371 })
372 else {
373 return Ok(None);
374 };
375
376 Ok(Some(Patches::new(
377 new_length,
378 0,
379 new_sparse_indices,
380 take(self.values(), &value_indices)?,
381 )))
382 }
383
384 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
385 where
386 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
387 {
388 let values = f(self.values)?;
389 if self.indices.len() != values.len() {
390 vortex_bail!(
391 "map_values must preserve length: expected {} received {}",
392 self.indices.len(),
393 values.len()
394 )
395 }
396 Ok(Self::new(self.array_len, self.offset, self.indices, values))
397 }
398}
399
400fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
401 indices: &[I],
402 take_indices: PrimitiveArray,
403 indices_offset: usize,
404) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
405where
406 usize: TryFrom<T>,
407 VortexError: From<<usize as TryFrom<T>>::Error>,
408{
409 let take_indices_validity = take_indices.validity();
410 let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
411
412 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
413 .as_slice::<T>()
414 .iter()
415 .map(|v| {
416 match I::from(*v) {
417 None => {
418 SearchResult::NotFound(indices.len())
420 }
421 Some(v) => indices.search_sorted(&(v + indices_offset), SearchSortedSide::Left),
422 }
423 })
424 .enumerate()
425 .filter_map(|(idx_in_take, search_result)| {
426 search_result
427 .to_found()
428 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
429 })
430 .unzip();
431
432 if new_indices.is_empty() {
433 return Ok(None);
434 }
435
436 let new_indices = new_indices.into_array();
437 let values_validity = take_indices_validity.take(&new_indices)?;
438 Ok(Some((
439 new_indices,
440 PrimitiveArray::new(values_indices, values_validity).into_array(),
441 )))
442}
443
444fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
445 indices: &[I],
446 take_indices: PrimitiveArray,
447 indices_offset: usize,
448 min_index: usize,
449 max_index: usize,
450) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
451where
452 usize: TryFrom<T>,
453 VortexError: From<<I as TryFrom<usize>>::Error>,
454{
455 let take_indices_validity = take_indices.validity();
456 let take_indices = take_indices.as_slice::<T>();
457 let offset_i = I::try_from(indices_offset)?;
458
459 let sparse_index_to_value_index: HashMap<I, usize> = indices
460 .iter()
461 .copied()
462 .map(|idx| idx - offset_i)
463 .enumerate()
464 .map(|(value_index, sparse_index)| (sparse_index, value_index))
465 .collect();
466 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
467 .iter()
468 .copied()
469 .map(usize::try_from)
470 .process_results(|iter| {
471 iter.enumerate()
472 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
473 .filter_map(|(new_sparse_index, take_sparse_index)| {
474 sparse_index_to_value_index
475 .get(
476 &I::try_from(take_sparse_index)
477 .vortex_expect("take_sparse_index is between min and max index"),
478 )
479 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
480 })
481 .unzip()
482 })
483 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
484
485 if new_sparse_indices.is_empty() {
486 return Ok(None);
487 }
488
489 let new_sparse_indices = new_sparse_indices.into_array();
490 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
491 Ok(Some((
492 new_sparse_indices,
493 PrimitiveArray::new(value_indices, values_validity).into_array(),
494 )))
495}
496
497fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
503 patch_indices: &[T],
504 offset: usize,
505 patch_values: &dyn Array,
506 mask_indices: &[usize],
507) -> VortexResult<Option<Patches>> {
508 let true_count = mask_indices.len();
509 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
510 let mut new_mask_indices = Vec::with_capacity(true_count);
511
512 const STRIDE: usize = 4;
516
517 let mut mask_idx = 0usize;
518 let mut true_idx = 0usize;
519
520 while mask_idx < patch_indices.len() && true_idx < true_count {
521 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
528 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
530 let left_max = patch_indices[mask_idx + STRIDE]
531 .to_usize()
532 .vortex_expect("left_max")
533 - offset;
534 let right_min = mask_indices[true_idx];
535 let right_max = mask_indices[true_idx + STRIDE];
536
537 if left_min > right_max {
538 true_idx += STRIDE;
540 continue;
541 } else if right_min > left_max {
542 mask_idx += STRIDE;
543 continue;
544 } else {
545 }
547 }
548
549 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
552 let right = mask_indices[true_idx];
553
554 match left.cmp(&right) {
555 Ordering::Less => {
556 mask_idx += 1;
557 }
558 Ordering::Greater => {
559 true_idx += 1;
560 }
561 Ordering::Equal => {
562 new_mask_indices.push(mask_idx);
564 new_patch_indices.push(true_idx as u64);
565
566 mask_idx += 1;
567 true_idx += 1;
568 }
569 }
570 }
571
572 if new_mask_indices.is_empty() {
573 return Ok(None);
574 }
575
576 let new_patch_indices = new_patch_indices.into_array();
577 let new_patch_values = filter(
578 patch_values,
579 &Mask::from_indices(patch_values.len(), new_mask_indices),
580 )?;
581
582 Ok(Some(Patches::new(
583 true_count,
584 0,
585 new_patch_indices,
586 new_patch_values,
587 )))
588}
589
/// Unit tests for `Patches`: filtering, value search, taking and slicing.
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::arrays::PrimitiveArray;
    use crate::patches::Patches;
    use crate::search_sorted::{SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps only patches whose position is selected by the mask and renumbers
    /// them by their rank among the selected positions.
    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Three patches (33, 44, 55) at positions 2, 9 and 15 of a 20-element array.
    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    /// A target above every patch value reports the position just past the last patch.
    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// A target below every patch value reports the position of the first patch.
    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    /// An exact hit reports the position of the matching patch.
    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    /// A right-sided miss above every patch value reports the position past the last patch.
    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// Positions reported after slicing are relative to the start of the slice.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    /// Right-sided searches over patches at positions 0, 1, 4 and 5.
    #[test]
    fn search_right() {
        let patches = Patches::new(
            6,
            0,
            buffer![0u8, 1, 4, 5].into_array(),
            buffer![-128i8, -98, 8, 50].into_array(),
        );

        assert_eq!(
            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(2)
        );
        assert_eq!(
            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(6),
        );
        assert_eq!(
            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(2),
        );
        assert_eq!(
            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(6)
        );
    }

    /// Left-sided misses that fall between patches.
    #[test]
    fn search_left() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        assert_eq!(
            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(19)
        );
    }

    /// Taking with a nullable index array: only the valid index 9 hits a patch.
    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// Slicing keeps only the patches inside the requested range.
    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
    }

    /// Slicing an already-sliced set of patches composes the offsets correctly.
    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();

        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
    }
}