vortex_array/arrays/dict/compute/
binary_numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use super::DictArray;
8use super::DictVTable;
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::compute::NumericKernel;
14use crate::compute::NumericKernelAdapter;
15use crate::compute::numeric;
16use crate::register_kernel;
17
18impl NumericKernel for DictVTable {
19    fn numeric(
20        &self,
21        lhs: &DictArray,
22        rhs: &dyn Array,
23        op: NumericOperator,
24    ) -> VortexResult<Option<ArrayRef>> {
25        // If we have more values than codes, it is faster to canonicalise first.
26        if lhs.values().len() > lhs.codes().len() {
27            return Ok(None);
28        }
29
30        // Only push down if all values are referenced to avoid incorrect results
31        // See: https://github.com/vortex-data/vortex/pull/4560
32        // Unchecked operation will be fine to pushdown.
33        if !lhs.has_all_values_referenced() {
34            return Ok(None);
35        }
36
37        // If the RHS is constant, then we just need to apply the operation to our encoded values.
38        if let Some(rhs_scalar) = rhs.as_constant() {
39            let values_result = numeric(
40                lhs.values(),
41                ConstantArray::new(rhs_scalar, lhs.values().len()).as_ref(),
42                op,
43            )?;
44
45            // SAFETY: values len preserved, codes all still point to valid values
46            // all_values_referenced preserved since operation doesn't change which values are referenced
47            let result = unsafe {
48                DictArray::new_unchecked(lhs.codes().clone(), values_result)
49                    .set_all_values_referenced(lhs.has_all_values_referenced())
50                    .into_array()
51            };
52
53            return Ok(Some(result));
54        }
55
56        // It's a little more complex, but we could perform binary operations against the dictionary
57        // values in the future.
58        Ok(None)
59    }
60}
61
// Passes the `DictVTable` numeric kernel, wrapped in `NumericKernelAdapter` and lifted, to
// `register_kernel!`. NOTE(review): presumably this is what lets the `numeric` compute function
// dispatch to the impl above - confirm against the macro definition.
62register_kernel!(NumericKernelAdapter(DictVTable).lift());
63
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_scalar::NumericOperator;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::compute::numeric;

    /// Builds a dictionary from `codes` and `values` with `all_values_referenced = true`.
    fn fully_referenced_dict(codes: impl IntoArray, values: impl IntoArray) -> DictArray {
        // SAFETY: every caller in this module passes codes that are valid indices into `values`
        // and that together reference each value at least once.
        unsafe {
            DictArray::new_unchecked(codes.into_array(), values.into_array())
                .set_all_values_referenced(true)
        }
    }

    /// Evaluates `dict <op> rhs`, with `rhs` broadcast to `N` elements, and checks that the
    /// canonicalized result equals `expected`.
    fn assert_constant_op<const N: usize>(
        dict: DictArray,
        rhs: i32,
        op: NumericOperator,
        expected: [i32; N],
    ) {
        let actual = numeric(dict.as_ref(), ConstantArray::new(rhs, N).as_ref(), op).unwrap();

        let expected = PrimitiveArray::from_iter(expected);
        assert_arrays_eq!(actual.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_add_const() {
        let dict = fully_referenced_dict(buffer![0u32, 1, 2, 0, 1], buffer![10i32, 20, 30]);

        assert_constant_op(dict, 5, NumericOperator::Add, [15, 25, 35, 15, 25]);
    }

    #[test]
    fn test_mul_const() {
        let dict = fully_referenced_dict(buffer![0u32, 1, 2, 1, 0], buffer![2i32, 3, 5]);

        assert_constant_op(dict, 10, NumericOperator::Mul, [20, 30, 50, 30, 20]);
    }

    #[test]
    fn test_no_pushdown_when_not_all_values_referenced() {
        // `try_new` leaves all_values_referenced = false (the default); the value at index 2 is
        // never referenced by the codes.
        let dict = DictArray::try_new(
            buffer![0u32, 1, 0, 1].into_array(),
            buffer![10i32, 20, 30].into_array(),
        )
        .unwrap();

        // The dictionary kernel should decline here (no pushdown); the overall computation must
        // still produce the correct canonical result.
        assert_constant_op(dict, 5, NumericOperator::Add, [15, 25, 15, 25]);
    }

    #[test]
    fn test_sub_const() {
        let dict = fully_referenced_dict(buffer![0u32, 1, 2], buffer![100i32, 50, 25]);

        assert_constant_op(dict, 10, NumericOperator::Sub, [90, 40, 15]);
    }
}
163}