1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::ToPrimitive;
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{
18 SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, search_sorted_usize,
19 search_sorted_usize_many, slice, take, try_cast,
20};
21use crate::variants::PrimitiveArrayTrait;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical};
23
/// Serializable summary of a [`Patches`] instance.
///
/// Holds the number of patches, the index offset, and the primitive type of the patch indices.
/// The index and value arrays themselves are not part of this struct (presumably they are stored
/// as children of the owning array — TODO confirm against the encodings that use it).
///
/// NOTE(review): the field order and `#[repr(C)]` layout are part of the serialized (serde/rkyv)
/// format; do not reorder fields.
#[derive(
    Copy,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::bytecheck::CheckBytes,
)]
#[bytecheck(crate = rkyv::bytecheck)]
#[repr(C)]
pub struct PatchesMetadata {
    // Number of patches, i.e. the length of the indices/values arrays.
    len: usize,
    // Offset that is subtracted from every stored index to obtain a position in the array.
    offset: usize,
    // Primitive type of the stored patch indices (expected to be an unsigned integer type).
    indices_ptype: PType,
}
42
43impl PatchesMetadata {
44 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
45 Self {
46 len,
47 offset,
48 indices_ptype,
49 }
50 }
51
52 #[inline]
53 pub fn len(&self) -> usize {
54 self.len
55 }
56
57 #[inline]
58 pub fn is_empty(&self) -> bool {
59 self.len == 0
60 }
61
62 #[inline]
63 pub fn offset(&self) -> usize {
64 self.offset
65 }
66
67 #[inline]
68 pub fn indices_dtype(&self) -> DType {
69 assert!(
70 self.indices_ptype.is_unsigned_int(),
71 "Patch indices must be unsigned integers"
72 );
73 DType::Primitive(self.indices_ptype, NonNullable)
74 }
75}
76
/// A sparse set of replacement values ("patches") applied to an array of length `array_len`.
///
/// Entry `i` of `values` replaces the element at array position `indices[i] - offset`.
#[derive(Debug, Clone)]
pub struct Patches {
    // Length of the array the patches apply to.
    array_len: usize,
    // Subtracted from every stored index to get an array position; `slice` shifts this value
    // instead of rewriting `indices`.
    offset: usize,
    // Unsigned integer positions (before subtracting `offset`). They are binary-searched and the
    // last one is treated as the maximum, so they are expected to be sorted ascending.
    indices: ArrayRef,
    // Replacement values, one per entry of `indices`.
    values: ArrayRef,
}
85
86impl Patches {
87 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
88 assert_eq!(
89 indices.len(),
90 values.len(),
91 "Patch indices and values must have the same length"
92 );
93 assert!(
94 indices.dtype().is_unsigned_int(),
95 "Patch indices must be unsigned integers"
96 );
97 assert!(
98 indices.len() <= array_len,
99 "Patch indices must be shorter than the array length"
100 );
101 assert!(!indices.is_empty(), "Patch indices must not be empty");
102 let max = usize::try_from(
103 &scalar_at(&indices, indices.len() - 1).vortex_expect("indices are not empty"),
104 )
105 .vortex_expect("indices must be a number");
106 assert!(
107 max - offset < array_len,
108 "Patch indices {:?}, offset {} are longer than the array length {}",
109 max,
110 offset,
111 array_len
112 );
113 Self::new_unchecked(array_len, offset, indices, values)
114 }
115
116 pub fn new_unchecked(
126 array_len: usize,
127 offset: usize,
128 indices: ArrayRef,
129 values: ArrayRef,
130 ) -> Self {
131 Self {
132 array_len,
133 offset,
134 indices,
135 values,
136 }
137 }
138
139 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
141 (self.array_len, self.offset, self.indices, self.values)
142 }
143
144 pub fn array_len(&self) -> usize {
145 self.array_len
146 }
147
148 pub fn num_patches(&self) -> usize {
149 self.indices.len()
150 }
151
152 pub fn dtype(&self) -> &DType {
153 self.values.dtype()
154 }
155
156 pub fn indices(&self) -> &ArrayRef {
157 &self.indices
158 }
159
160 pub fn into_indices(self) -> ArrayRef {
161 self.indices
162 }
163
164 pub fn values(&self) -> &ArrayRef {
165 &self.values
166 }
167
168 pub fn into_values(self) -> ArrayRef {
169 self.values
170 }
171
172 pub fn offset(&self) -> usize {
173 self.offset
174 }
175
176 pub fn indices_ptype(&self) -> PType {
177 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
178 }
179
180 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
181 if self.indices.len() > len {
182 vortex_bail!(
183 "Patch indices {} are longer than the array length {}",
184 self.indices.len(),
185 len
186 );
187 }
188 if self.values.dtype() != dtype {
189 vortex_bail!(
190 "Patch values dtype {} does not match array dtype {}",
191 self.values.dtype(),
192 dtype
193 );
194 }
195 Ok(PatchesMetadata {
196 len: self.indices.len(),
197 offset: self.offset,
198 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices"),
199 })
200 }
201
202 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
203 Ok(Self::new_unchecked(
204 self.array_len,
205 self.offset,
206 self.indices,
207 try_cast(&self.values, values_dtype)?,
208 ))
209 }
210
211 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
213 if let Some(patch_idx) = self.search_index(index)?.to_found() {
214 scalar_at(self.values(), patch_idx).map(Some)
215 } else {
216 Ok(None)
217 }
218 }
219
220 fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
222 search_sorted_usize(&self.indices, index + self.offset, SearchSortedSide::Left)
223 }
224
225 pub fn search_sorted<T: Into<Scalar>>(
227 &self,
228 target: T,
229 side: SearchSortedSide,
230 ) -> VortexResult<SearchResult> {
231 search_sorted(self.values(), target.into(), side).and_then(|sr| {
232 let sidx = sr.to_offsets_index(self.indices().len());
233 let index = usize::try_from(&scalar_at(self.indices(), sidx)?)? - self.offset;
234 Ok(match sr {
235 SearchResult::Found(i) => SearchResult::Found(if i == self.indices().len() {
237 index + 1
238 } else {
239 index
240 }),
241 SearchResult::NotFound(i) => {
243 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
244 }
245 })
246 })
247 }
248
249 pub fn min_index(&self) -> VortexResult<usize> {
251 Ok(usize::try_from(&scalar_at(self.indices(), 0)?)? - self.offset)
252 }
253
254 pub fn max_index(&self) -> VortexResult<usize> {
256 Ok(usize::try_from(&scalar_at(self.indices(), self.indices().len() - 1)?)? - self.offset)
257 }
258
259 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
261 match mask.indices() {
262 AllOr::All => Ok(Some(self.clone())),
263 AllOr::None => Ok(None),
264 AllOr::Some(mask_indices) => {
265 let flat_indices = self.indices().to_primitive()?;
266 match_each_integer_ptype!(flat_indices.ptype(), |$I| {
267 filter_patches_with_mask(
268 flat_indices.as_slice::<$I>(),
269 self.offset(),
270 self.values(),
271 mask_indices,
272 )
273 })
274 }
275 }
276 }
277
278 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
280 let patch_start = self.search_index(start)?.to_index();
281 let patch_stop = self.search_index(stop)?.to_index();
282
283 if patch_start == patch_stop {
284 return Ok(None);
285 }
286
287 let values = slice(self.values(), patch_start, patch_stop)?;
289 let indices = slice(self.indices(), patch_start, patch_stop)?;
290
291 Ok(Some(Self::new(
292 stop - start,
293 start + self.offset(),
294 indices,
295 values,
296 )))
297 }
298
299 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
301
302 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
303 (self.num_patches() as f64 / take_indices.len() as f64)
304 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
305 }
306
307 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
309 if take_indices.is_empty() {
310 return Ok(None);
311 }
312 let take_indices = take_indices.to_primitive()?;
313 if self.is_map_faster_than_search(&take_indices) {
314 self.take_map(take_indices)
315 } else {
316 self.take_search(take_indices)
317 }
318 }
319
320 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
321 let new_length = take_indices.len();
322
323 let Some((new_indices, values_indices)) = match_each_integer_ptype!(take_indices.ptype(), |$I| {
324 take_search::<$I>(self.indices(), take_indices, self.offset())?
325 }) else {
326 return Ok(None);
327 };
328
329 Ok(Some(Self::new(
330 new_length,
331 0,
332 new_indices,
333 take(self.values(), &values_indices)?,
334 )))
335 }
336
337 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
338 let indices = self.indices.to_primitive()?;
339 let new_length = take_indices.len();
340
341 let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
342 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
343 take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
344 })
345 }) else {
346 return Ok(None);
347 };
348
349 Ok(Some(Patches::new(
350 new_length,
351 0,
352 new_sparse_indices,
353 take(self.values(), &value_indices)?,
354 )))
355 }
356
357 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
358 where
359 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
360 {
361 let values = f(self.values)?;
362 if self.indices.len() != values.len() {
363 vortex_bail!(
364 "map_values must preserve length: expected {} received {}",
365 self.indices.len(),
366 values.len()
367 )
368 }
369 Ok(Self::new(self.array_len, self.offset, self.indices, values))
370 }
371}
372
373fn take_search<T: NativePType + TryFrom<usize>>(
374 indices: &dyn Array,
375 take_indices: PrimitiveArray,
376 indices_offset: usize,
377) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
378where
379 usize: TryFrom<T>,
380 VortexError: From<<usize as TryFrom<T>>::Error>,
381{
382 let take_indices_validity = take_indices.validity();
383 let take_indices = take_indices
384 .as_slice::<T>()
385 .iter()
386 .copied()
387 .map(usize::try_from)
388 .map_ok(|idx| idx + indices_offset)
389 .collect::<Result<Vec<_>, _>>()?;
390
391 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) =
392 search_sorted_usize_many(indices, &take_indices, SearchSortedSide::Left)?
393 .iter()
394 .enumerate()
395 .filter_map(|(idx_in_take, search_result)| {
396 search_result
397 .to_found()
398 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
399 })
400 .unzip();
401
402 if new_indices.is_empty() {
403 return Ok(None);
404 }
405
406 let new_indices = new_indices.into_array();
407 let values_validity = take_indices_validity.take(&new_indices)?;
408 Ok(Some((
409 new_indices,
410 PrimitiveArray::new(values_indices, values_validity).into_array(),
411 )))
412}
413
414fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
415 indices: &[I],
416 take_indices: PrimitiveArray,
417 indices_offset: usize,
418 min_index: usize,
419 max_index: usize,
420) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
421where
422 usize: TryFrom<T>,
423 VortexError: From<<I as TryFrom<usize>>::Error>,
424{
425 let take_indices_validity = take_indices.validity();
426 let take_indices = take_indices.as_slice::<T>();
427 let offset_i = I::try_from(indices_offset)?;
428
429 let sparse_index_to_value_index: HashMap<I, usize> = indices
430 .iter()
431 .copied()
432 .map(|idx| idx - offset_i)
433 .enumerate()
434 .map(|(value_index, sparse_index)| (sparse_index, value_index))
435 .collect();
436 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
437 .iter()
438 .copied()
439 .map(usize::try_from)
440 .process_results(|iter| {
441 iter.enumerate()
442 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
443 .filter_map(|(new_sparse_index, take_sparse_index)| {
444 sparse_index_to_value_index
445 .get(
446 &I::try_from(take_sparse_index)
447 .vortex_expect("take_sparse_index is between min and max index"),
448 )
449 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
450 })
451 .unzip()
452 })
453 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
454
455 if new_sparse_indices.is_empty() {
456 return Ok(None);
457 }
458
459 let new_sparse_indices = new_sparse_indices.into_array();
460 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
461 Ok(Some((
462 new_sparse_indices,
463 PrimitiveArray::new(value_indices, values_validity).into_array(),
464 )))
465}
466
467fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
473 patch_indices: &[T],
474 offset: usize,
475 patch_values: &dyn Array,
476 mask_indices: &[usize],
477) -> VortexResult<Option<Patches>> {
478 let true_count = mask_indices.len();
479 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
480 let mut new_mask_indices = Vec::with_capacity(true_count);
481
482 const STRIDE: usize = 4;
486
487 let mut mask_idx = 0usize;
488 let mut true_idx = 0usize;
489
490 while mask_idx < patch_indices.len() && true_idx < true_count {
491 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
498 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
500 let left_max = patch_indices[mask_idx + STRIDE]
501 .to_usize()
502 .vortex_expect("left_max")
503 - offset;
504 let right_min = mask_indices[true_idx];
505 let right_max = mask_indices[true_idx + STRIDE];
506
507 if left_min > right_max {
508 true_idx += STRIDE;
510 continue;
511 } else if right_min > left_max {
512 mask_idx += STRIDE;
513 continue;
514 } else {
515 }
517 }
518
519 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
522 let right = mask_indices[true_idx];
523
524 match left.cmp(&right) {
525 Ordering::Less => {
526 mask_idx += 1;
527 }
528 Ordering::Greater => {
529 true_idx += 1;
530 }
531 Ordering::Equal => {
532 new_mask_indices.push(mask_idx);
534 new_patch_indices.push(true_idx as u64);
535
536 mask_idx += 1;
537 true_idx += 1;
538 }
539 }
540 }
541
542 if new_mask_indices.is_empty() {
543 return Ok(None);
544 }
545
546 let new_patch_indices = new_patch_indices.into_array();
547 let new_patch_values = filter(
548 patch_values,
549 &Mask::from_indices(patch_values.len(), new_mask_indices),
550 )?;
551
552 Ok(Some(Patches::new(
553 true_count,
554 0,
555 new_patch_indices,
556 new_patch_values,
557 )))
558}
559
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::compute::{SearchResult, SearchSortedSide};
    use crate::patches::Patches;
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps only the patches at selected positions and re-bases their indices onto the
    /// filtered array.
    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Patches at positions 2, 9 and 15 of a 20-element array with values 33, 44 and 55.
    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    /// A target above every patch value lands just after the last patch position.
    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// A target below every patch value lands on the first patch position.
    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    /// An exact value match reports the matching patch's array position.
    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    /// A right-side search for a missing target above every value lands after the last patch.
    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// Searching a slice reports positions relative to the slice start.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    /// Right-side searches on a single patch report the position just past it.
    #[test]
    fn search_right() {
        let patches = Patches::new(
            2,
            0,
            buffer![0u64].into_array(),
            PrimitiveArray::new(buffer![0u8], Validity::AllValid).into_array(),
        );

        assert_eq!(
            patches.search_sorted(0, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(1)
        );
        assert_eq!(
            patches.search_sorted(1, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(1)
        );
    }

    /// Take indices `[9, null]` keep the single patch hit at position 9 (value 44). The null slot
    /// holds 0, which is not patched.
    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }
}