vortex_dict/compute/binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{NumericKernel, NumericKernelAdapter, numeric};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_scalar::NumericOperator;
9
10use crate::{DictArray, DictVTable};
11
12impl NumericKernel for DictVTable {
13    fn numeric(
14        &self,
15        array: &DictArray,
16        rhs: &dyn Array,
17        op: NumericOperator,
18    ) -> VortexResult<Option<ArrayRef>> {
19        // if we have more values than codes, it is faster to canonicalise first.
20        if array.values().len() > array.codes().len() {
21            return Ok(None);
22        }
23
24        let Some(rhs_scalar) = rhs.as_constant() else {
25            return Ok(None);
26        };
27        let rhs_const_array = ConstantArray::new(rhs_scalar, array.values().len()).into_array();
28
29        // SAFETY: applying numeric fn to values does not change codes validity
30        unsafe {
31            Ok(Some(
32                DictArray::new_unchecked(
33                    array.codes().clone(),
34                    numeric(array.values(), &rhs_const_array, op)?,
35                )
36                .into_array(),
37            ))
38        }
39    }
40}
41
42register_kernel!(NumericKernelAdapter(DictVTable).lift());
43
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
    use vortex_array::{ArrayRef, IntoArray};

    use crate::builders::dict_encode;

    /// Dictionary-encodes a small nullable `i32` array (containing a repeated value and a
    /// null) and returns the slice `1..4` of it.
    ///
    /// NOTE(review): the slice likely leaves more dictionary values than codes, which would
    /// exercise the kernel's canonical fallback path rather than the push-down. Confirm this
    /// against `dict_encode`'s output.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        dict.slice(1..4)
    }

    /// Runs the shared binary-numeric conformance suite on a sliced, nullable dictionary.
    #[test]
    fn test_dict_binary_numeric() {
        let array = sliced_dict_array();
        test_binary_numeric_array(array)
    }

    /// Runs the shared binary-numeric conformance suite across dictionary-encoded arrays of
    /// several primitive types, plus a sliced and a nullable variant.
    #[rstest]
    #[case::dict_i32_basic(dict_encode(PrimitiveArray::from_iter([10i32, 20, 10, 30, 20, 10]).as_ref()).unwrap().into_array())]
    #[case::dict_u32_basic(dict_encode(PrimitiveArray::from_iter([100u32, 200, 100, 300, 200]).as_ref()).unwrap().into_array())]
    #[case::dict_i64_basic(dict_encode(PrimitiveArray::from_iter([1000i64, 2000, 1000, 3000, 2000, 1000]).as_ref()).unwrap().into_array())]
    #[case::dict_u64_basic(dict_encode(PrimitiveArray::from_iter([5000u64, 6000, 5000, 7000, 6000]).as_ref()).unwrap().into_array())]
    #[case::dict_f32_basic(dict_encode(PrimitiveArray::from_iter([1.5f32, 2.5, 1.5, 3.5, 2.5]).as_ref()).unwrap().into_array())]
    #[case::dict_f64_basic(dict_encode(PrimitiveArray::from_iter([10.1f64, 20.2, 10.1, 30.3, 20.2]).as_ref()).unwrap().into_array())]
    #[case::dict_i32_sliced(dict_encode(PrimitiveArray::from_iter([100i32, 200, 100, 300, 200, 100]).as_ref()).unwrap().slice(1..5))]
    #[case::dict_nullable(dict_encode(PrimitiveArray::from_option_iter([Some(42i32), None, Some(42), Some(1), None]).as_ref()).unwrap().into_array())]
    fn test_dict_binary_numeric_rstest(#[case] array: ArrayRef) {
        test_binary_numeric_array(array)
    }
}