vortex_alp/alp_rd/compute/
take.rs1use vortex_array::compute::{TakeFn, fill_null, take};
2use vortex_array::{Array, ArrayRef};
3use vortex_error::VortexResult;
4use vortex_scalar::{Scalar, ScalarValue};
5
6use crate::{ALPRDArray, ALPRDEncoding};
7
8impl TakeFn<&ALPRDArray> for ALPRDEncoding {
9 fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
10 let taken_left_parts = take(array.left_parts(), indices)?;
11 let left_parts_exceptions = array
12 .left_parts_patches()
13 .map(|patches| patches.take(indices))
14 .transpose()?
15 .flatten()
16 .map(|p| {
17 let values_dtype = p
18 .values()
19 .dtype()
20 .with_nullability(taken_left_parts.dtype().nullability());
21 p.cast_values(&values_dtype)
22 })
23 .transpose()?;
24 let right_parts = fill_null(
25 &take(array.right_parts(), indices)?,
26 Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
27 )?;
28
29 Ok(ALPRDArray::try_new(
30 array
31 .dtype()
32 .with_nullability(taken_left_parts.dtype().nullability()),
33 taken_left_parts,
34 array.left_parts_dictionary().clone(),
35 right_parts,
36 array.right_bit_width(),
37 left_parts_exceptions,
38 )?
39 .into_array())
40 }
41}
42
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::{Array, ToCanonical};

    use crate::{ALPRDArray, ALPRDFloat, RDEncoder};

    /// Encodes `[a, b, outlier]` with an `RDEncoder` built from `[a, b]`.
    ///
    /// Asserts that the encoding produced left-parts patches and that those
    /// patches report an unsigned-integer dtype, then returns the encoded array.
    fn encode_with_outlier<T: ALPRDFloat>(a: T, b: T, outlier: T) -> ALPRDArray {
        let encoded = RDEncoder::new(&[a, b]).encode(&PrimitiveArray::from_iter([a, b, outlier]));

        let patches = encoded.left_parts_patches();
        assert!(patches.is_some());
        assert!(patches.unwrap().dtype().is_unsigned_int());

        encoded
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);
        let indices = PrimitiveArray::from_iter([0, 2]);

        let taken = take(&encoded, &indices).unwrap().to_primitive().unwrap();

        // Selecting index 2 must pull the patched outlier back out intact.
        assert_eq!(taken.as_slice::<T>(), &[a, outlier]);
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);
        let indices = PrimitiveArray::from_option_iter([Some(0), Some(2), None]);

        let taken = take(&encoded, &indices).unwrap().to_primitive().unwrap();

        let values = taken.as_slice::<T>();
        assert_eq!(values[0], a);
        assert_eq!(values[1], outlier);
        // A null index must produce a null output slot.
        assert!(!taken.validity_mask().unwrap().value(2));
    }
}