vortex_dict/compute/binary_numeric.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{NumericKernel, NumericKernelAdapter, numeric};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8use vortex_scalar::NumericOperator;
9
10use crate::{DictArray, DictVTable};
11
12impl NumericKernel for DictVTable {
13    fn numeric(
14        &self,
15        array: &DictArray,
16        rhs: &dyn Array,
17        op: NumericOperator,
18    ) -> VortexResult<Option<ArrayRef>> {
19        // if we have more values than codes, it is faster to canonicalise first.
20        if array.values().len() > array.codes().len() {
21            return Ok(None);
22        }
23
24        let Some(rhs_scalar) = rhs.as_constant() else {
25            return Ok(None);
26        };
27        let rhs_const_array = ConstantArray::new(rhs_scalar, array.values().len()).into_array();
28
29        Ok(Some(
30            DictArray::try_new(
31                array.codes().clone(),
32                numeric(array.values(), &rhs_const_array, op)?,
33            )?
34            .into_array(),
35        ))
36    }
37}
38
// Registers the `DictVTable` numeric kernel, wrapped in `NumericKernelAdapter` and lifted.
// NOTE(review): presumably this is what lets the generic `numeric` compute function dispatch
// to the dictionary kernel above; confirm against the `register_kernel!` definition.
register_kernel!(NumericKernelAdapter(DictVTable).lift());
40
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::binary_numeric::test_numeric;
    use vortex_array::{ArrayRef, IntoArray};

    use crate::builders::dict_encode;

    /// Nullable `i32` input shared by the dictionary fixtures: six rows holding four distinct
    /// non-null values (42 appears twice) and one null.
    fn reference_array() -> PrimitiveArray {
        PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ])
    }

    /// Dictionary encoding of all six rows of [`reference_array`].
    ///
    /// There are six codes but at most five dictionary values, so the kernel's
    /// `values.len() > codes.len()` bail-out does not fire and the dictionary fast path runs.
    fn dict_array() -> ArrayRef {
        dict_encode(reference_array().as_ref())
            .unwrap()
            .into_array()
    }

    /// Rows `1..4` of the dictionary-encoded [`reference_array`], i.e. `[-9, null, 42]`.
    ///
    /// NOTE: assuming slicing keeps the full dictionary, this has 3 codes against at least 4
    /// values, so it covers the kernel's fallback (`Ok(None)`) path rather than the fast path.
    fn sliced_dict_array() -> ArrayRef {
        let dict = dict_encode(reference_array().as_ref()).unwrap();
        dict.slice(1, 4).unwrap()
    }

    #[test]
    fn test_dict_binary_numeric() {
        let array = sliced_dict_array();
        test_numeric::<i32>(array)
    }

    /// Exercises the dictionary fast path, where the op is applied to the values only.
    #[test]
    fn test_dict_binary_numeric_unsliced() {
        test_numeric::<i32>(dict_array())
    }
}