// vortex_alp/alp_rd/compute/take.rs

1use vortex_array::compute::{TakeFn, fill_null, take};
2use vortex_array::{Array, ArrayRef};
3use vortex_error::VortexResult;
4use vortex_scalar::{Scalar, ScalarValue};
5
6use crate::{ALPRDArray, ALPRDEncoding};
7
8impl TakeFn<&ALPRDArray> for ALPRDEncoding {
9    fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
10        let taken_left_parts = take(array.left_parts(), indices)?;
11        let left_parts_exceptions = array
12            .left_parts_patches()
13            .map(|patches| patches.take(indices))
14            .transpose()?
15            .flatten()
16            .map(|p| {
17                let values_dtype = p
18                    .values()
19                    .dtype()
20                    .with_nullability(taken_left_parts.dtype().nullability());
21                p.cast_values(&values_dtype)
22            })
23            .transpose()?;
24        let right_parts = fill_null(
25            &take(array.right_parts(), indices)?,
26            Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
27        )?;
28
29        Ok(ALPRDArray::try_new(
30            array
31                .dtype()
32                .with_nullability(taken_left_parts.dtype().nullability()),
33            taken_left_parts,
34            array.left_parts_dictionary().clone(),
35            right_parts,
36            array.right_bit_width(),
37            left_parts_exceptions,
38        )?
39        .into_array())
40    }
41}
42
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::take;
    use vortex_array::{Array, ToCanonical};

    use crate::{ALPRDArray, ALPRDFloat, RDEncoder};

    /// Encodes `[a, b, outlier]` with an encoder built from `[a, b]` only, so that `outlier`
    /// ends up stored as a left-parts exception.
    ///
    /// Asserts that exception patches were actually produced and that their dtype is an
    /// unsigned integer, so the take tests below really exercise the patch path.
    fn encode_with_outlier<T: ALPRDFloat>(a: T, b: T, outlier: T) -> ALPRDArray {
        let array = PrimitiveArray::from_iter([a, b, outlier]);
        let encoded = RDEncoder::new(&[a, b]).encode(&array);

        let patches = encoded
            .left_parts_patches()
            .expect("outlier should be stored as a left-parts exception");
        assert!(patches.dtype().is_unsigned_int());

        encoded
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_take<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);

        let taken = take(&encoded, &PrimitiveArray::from_iter([0, 2]))
            .unwrap()
            .to_primitive()
            .unwrap();

        // Selecting the patched position must reproduce the outlier exactly.
        assert_eq!(taken.as_slice::<T>(), &[a, outlier]);
    }

    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoded = encode_with_outlier(a, b, outlier);

        let taken = take(
            &encoded,
            &PrimitiveArray::from_option_iter([Some(0), Some(2), None]),
        )
        .unwrap()
        .to_primitive()
        .unwrap();

        // One output row per index, and a null index forces a nullable result dtype.
        assert_eq!(taken.len(), 3);
        assert!(taken.dtype().is_nullable());

        let values = taken.as_slice::<T>();
        assert_eq!(values[0], a);
        assert_eq!(values[1], outlier);

        // Only the row selected by the null index is invalid.
        let validity = taken.validity_mask().unwrap();
        assert!(validity.value(0));
        assert!(validity.value(1));
        assert!(!validity.value(2));
    }
}