Skip to main content

vortex_alp/alp_rd/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::scalar::Scalar;
11use vortex_error::VortexResult;
12
13use crate::ALPRD;
14use crate::ALPRDArrayExt;
15
16impl TakeExecute for ALPRD {
17    fn take(
18        array: ArrayView<'_, Self>,
19        indices: &ArrayRef,
20        ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let taken_left_parts = array.left_parts().take(indices.clone())?;
23        let left_parts_exceptions = array
24            .left_parts_patches()
25            .map(|patches| patches.take(indices, ctx))
26            .transpose()?
27            .flatten()
28            .map(|p| {
29                let values_dtype = p
30                    .values()
31                    .dtype()
32                    .with_nullability(taken_left_parts.dtype().nullability());
33                p.cast_values(&values_dtype)
34            })
35            .transpose()?;
36        let right_parts = array
37            .right_parts()
38            .take(indices.clone())?
39            .fill_null(Scalar::zero_value(array.right_parts().dtype()))?;
40
41        Ok(Some(
42            ALPRD::try_new(
43                array
44                    .dtype()
45                    .with_nullability(taken_left_parts.dtype().nullability()),
46                taken_left_parts,
47                array.left_parts_dictionary().clone(),
48                right_parts,
49                array.right_bit_width(),
50                left_parts_exceptions,
51            )?
52            .into_array(),
53        ))
54    }
55}
56
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_buffer::buffer;

    use crate::ALPRDArrayExt;
    use crate::ALPRDFloat;
    use crate::RDEncoder;

    /// Taking non-null indices returns both a sampled value and the outlier.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let array = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&array);

        // The encoder is trained on `[a, b]` only; the test expects the outlier to produce
        // left-part patches with an unsigned-integer dtype.
        let patches = encoded
            .left_parts_patches()
            .expect("outlier should produce left-part patches");
        assert!(patches.dtype().is_unsigned_int());

        let taken = encoded
            .take(buffer![0, 2].into_array())
            .unwrap()
            .to_primitive();

        assert_arrays_eq!(taken, PrimitiveArray::from_iter([a, outlier]));
    }

    /// A null index yields a null output row, while the patched outlier is still recovered.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let array = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&array);

        let patches = encoded
            .left_parts_patches()
            .expect("outlier should produce left-part patches");
        assert!(patches.dtype().is_unsigned_int());

        let taken = encoded
            .take(PrimitiveArray::from_option_iter([Some(0), Some(2), None]).into_array())
            .unwrap()
            .to_primitive();

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([Some(a), Some(outlier), None])
        );
    }

    /// Runs the shared take conformance suite over a non-nullable ALP-RD array with outliers.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_conformance_alprd<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        test_take_conformance(
            &RDEncoder::new(&[a, b])
                .encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier]))
                .into_array(),
        );
    }

    /// Runs the shared take conformance suite over a nullable ALP-RD array with an outlier.
    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_take_with_nulls_conformance<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        test_take_conformance(
            &RDEncoder::new(&[a])
                .encode(&PrimitiveArray::from_option_iter([
                    Some(a),
                    None,
                    Some(outlier),
                    Some(a),
                    None,
                ]))
                .into_array(),
        );
    }
}