Skip to main content

vortex_alp/alp_rd/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::scalar::Scalar;
11use vortex_error::VortexResult;
12
13use crate::ALPRD;
14use crate::ALPRDArrayExt;
15
16impl TakeExecute for ALPRD {
17    fn take(
18        array: ArrayView<'_, Self>,
19        indices: &ArrayRef,
20        ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let taken_left_parts = array.left_parts().take(indices.clone())?;
23        let left_parts_exceptions = array
24            .left_parts_patches()
25            .map(|patches| patches.take(indices, ctx))
26            .transpose()?
27            .flatten()
28            .map(|p| {
29                let values_dtype = p
30                    .values()
31                    .dtype()
32                    .with_nullability(taken_left_parts.dtype().nullability());
33                p.cast_values(&values_dtype)
34            })
35            .transpose()?;
36        let right_parts = array
37            .right_parts()
38            .take(indices.clone())?
39            .fill_null(Scalar::zero_value(array.right_parts().dtype()))?;
40
41        Ok(Some(
42            ALPRD::try_new(
43                array
44                    .dtype()
45                    .with_nullability(taken_left_parts.dtype().nullability()),
46                taken_left_parts,
47                array.left_parts_dictionary().clone(),
48                right_parts,
49                array.right_bit_width(),
50                left_parts_exceptions,
51                ctx,
52            )?
53            .into_array(),
54        ))
55    }
56}
57
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_buffer::buffer;

    use crate::ALPRDArrayExt;
    use crate::ALPRDFloat;
    use crate::RDEncoder;

    /// Taking a subset that includes the outlier must reproduce the patched value.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([a, b, outlier]);
        // The encoder is built from `a` and `b` only; `outlier` is expected to be stored
        // as a left-parts patch (asserted below).
        let encoded = RDEncoder::new(&[a, b]).encode(array.as_view(), &mut ctx);

        let patches = encoded
            .left_parts_patches()
            .expect("outlier should be stored as a left-parts patch");
        assert!(patches.dtype().is_unsigned_int());

        let taken = encoded
            .take(buffer![0, 2].into_array())
            .unwrap()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();

        assert_arrays_eq!(taken, PrimitiveArray::from_iter([a, outlier]));
    }

    /// A null index must yield a null output row while patched rows stay intact.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(array.as_view(), &mut ctx);

        let patches = encoded
            .left_parts_patches()
            .expect("outlier should be stored as a left-parts patch");
        assert!(patches.dtype().is_unsigned_int());

        let taken = encoded
            .take(PrimitiveArray::from_option_iter([Some(0), Some(2), None]).into_array())
            .unwrap()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();

        assert_arrays_eq!(
            taken,
            PrimitiveArray::from_option_iter([Some(a), Some(outlier), None])
        );
    }

    /// Runs the shared take conformance suite on a non-nullable array with patches.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take_conformance_alprd<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        test_take_conformance(
            &RDEncoder::new(&[a, b])
                .encode(
                    PrimitiveArray::from_iter([a, b, outlier, b, outlier]).as_view(),
                    &mut ctx,
                )
                .into_array(),
        );
    }

    /// Runs the shared take conformance suite on a nullable array with patches.
    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_take_with_nulls_conformance<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        test_take_conformance(
            &RDEncoder::new(&[a])
                .encode(
                    PrimitiveArray::from_option_iter([Some(a), None, Some(outlier), Some(a), None])
                        .as_view(),
                    &mut ctx,
                )
                .into_array(),
        );
    }
}