vortex_alp/alp_rd/compute/
take.rs1use vortex_array::compute::{TakeKernel, TakeKernelAdapter, fill_null, take};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_error::VortexResult;
7use vortex_scalar::{Scalar, ScalarValue};
8
9use crate::{ALPRDArray, ALPRDVTable};
10
11impl TakeKernel for ALPRDVTable {
12 fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13 let taken_left_parts = take(array.left_parts(), indices)?;
14 let left_parts_exceptions = array
15 .left_parts_patches()
16 .map(|patches| patches.take(indices))
17 .transpose()?
18 .flatten()
19 .map(|p| {
20 let values_dtype = p
21 .values()
22 .dtype()
23 .with_nullability(taken_left_parts.dtype().nullability());
24 p.cast_values(&values_dtype)
25 })
26 .transpose()?;
27 let right_parts = fill_null(
28 &take(array.right_parts(), indices)?,
29 &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
30 )?;
31
32 Ok(ALPRDArray::try_new(
33 array
34 .dtype()
35 .with_nullability(taken_left_parts.dtype().nullability()),
36 taken_left_parts,
37 array.left_parts_dictionary().clone(),
38 right_parts,
39 array.right_bit_width(),
40 left_parts_exceptions,
41 )?
42 .into_array())
43 }
44}
45
// Register the `TakeKernel` impl above, wrapped in `TakeKernelAdapter`, via `register_kernel!`.
// NOTE(review): presumably this makes `vortex_array::compute::take` dispatch to the ALP-RD
// kernel for `ALPRDVTable` arrays, with `.lift()` adapting the adapter to the registry's kernel
// type — confirm against the `register_kernel!` definition in `vortex_array`.
register_kernel!(TakeKernelAdapter(ALPRDVTable).lift());
47
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::{IntoArray as _, ToCanonical, assert_arrays_eq};
    use vortex_buffer::buffer;

    use crate::{ALPRDArray, ALPRDFloat, RDEncoder};

    /// RD-encodes `[a, b, outlier]` with an encoder built from the sample `[a, b]`.
    ///
    /// Asserts that the encoding carries left-part patches whose dtype is an unsigned integer,
    /// then returns the encoded array.
    fn encode_with_outlier<T: ALPRDFloat>(a: T, b: T, outlier: T) -> ALPRDArray {
        let values = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&values);

        let Some(patches) = encoded.left_parts_patches() else {
            panic!("expected the encoding to carry left-parts patches");
        };
        assert!(patches.dtype().is_unsigned_int());

        encoded
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);
        let indices = buffer![0, 2].into_array();

        let taken = take(encoded.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive();

        assert_arrays_eq!(taken, PrimitiveArray::from_iter([a, outlier]));
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);
        // A null index must produce a null output row.
        let indices = PrimitiveArray::from_option_iter([Some(0), Some(2), None]);

        let taken = take(encoded.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive();

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([Some(a), Some(outlier), None])
        );
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_conformance_alprd<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let values = PrimitiveArray::from_iter([a, b, outlier, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&values);

        test_take_conformance(&encoded.to_array());
    }

    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_take_with_nulls_conformance<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        let values =
            PrimitiveArray::from_option_iter([Some(a), None, Some(outlier), Some(a), None]);
        let encoded = RDEncoder::new(&[a]).encode(&values);

        test_take_conformance(&encoded.to_array());
    }
}