vortex_alp/alp_rd/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{TakeKernel, TakeKernelAdapter, fill_null, take};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_error::VortexResult;
7use vortex_scalar::{Scalar, ScalarValue};
8
9use crate::{ALPRDArray, ALPRDVTable};
10
11impl TakeKernel for ALPRDVTable {
12    fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
13        let taken_left_parts = take(array.left_parts(), indices)?;
14        let left_parts_exceptions = array
15            .left_parts_patches()
16            .map(|patches| patches.take(indices))
17            .transpose()?
18            .flatten()
19            .map(|p| {
20                let values_dtype = p
21                    .values()
22                    .dtype()
23                    .with_nullability(taken_left_parts.dtype().nullability());
24                p.cast_values(&values_dtype)
25            })
26            .transpose()?;
27        let right_parts = fill_null(
28            &take(array.right_parts(), indices)?,
29            &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
30        )?;
31
32        Ok(ALPRDArray::try_new(
33            array
34                .dtype()
35                .with_nullability(taken_left_parts.dtype().nullability()),
36            taken_left_parts,
37            array.left_parts_dictionary().clone(),
38            right_parts,
39            array.right_bit_width(),
40            left_parts_exceptions,
41        )?
42        .into_array())
43    }
44}
45
46register_kernel!(TakeKernelAdapter(ALPRDVTable).lift());
47
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::{ToCanonical, assert_arrays_eq};

    use crate::{ALPRDFloat, RDEncoder};

    /// Taking non-null indices from an encoded array that has left-part patches yields the
    /// original values at those positions.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        use vortex_array::IntoArray as _;
        use vortex_buffer::buffer;

        let input = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&input);

        // The encoding must have produced unsigned-integer left-part patches.
        let patches = encoded.left_parts_patches();
        assert!(patches.is_some());
        assert!(patches.unwrap().dtype().is_unsigned_int());

        let indices = buffer![0, 2].into_array();
        let taken = take(encoded.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive();

        let expected = PrimitiveArray::from_iter([a, outlier]);
        assert_arrays_eq!(taken, expected);
    }

    /// A null index produces a null element in the result, next to the regular taken values.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let input = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&input);

        // The encoding must have produced unsigned-integer left-part patches.
        let patches = encoded.left_parts_patches();
        assert!(patches.is_some());
        assert!(patches.unwrap().dtype().is_unsigned_int());

        let indices = PrimitiveArray::from_option_iter([Some(0), Some(2), None]);
        let taken = take(encoded.as_ref(), indices.as_ref())
            .unwrap()
            .to_primitive();

        let expected = PrimitiveArray::from_option_iter([Some(a), Some(outlier), None]);
        assert_arrays_eq!(taken, expected);
    }

    /// Runs the shared take conformance suite against an encoded array without nulls.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_conformance_alprd<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let input = PrimitiveArray::from_iter([a, b, outlier, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&input).to_array();

        test_take_conformance(&encoded);
    }

    /// Runs the shared take conformance suite against an encoded array that contains nulls.
    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_take_with_nulls_conformance<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        let input =
            PrimitiveArray::from_option_iter([Some(a), None, Some(outlier), Some(a), None]);
        let encoded = RDEncoder::new(&[a]).encode(&input).to_array();

        test_take_conformance(&encoded);
    }
}