vortex_array/arrays/dict/compute/
binary_numeric.rs1use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use super::DictArray;
8use super::DictVTable;
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::compute::NumericKernel;
14use crate::compute::NumericKernelAdapter;
15use crate::compute::numeric;
16use crate::register_kernel;
17
18impl NumericKernel for DictVTable {
19 fn numeric(
20 &self,
21 lhs: &DictArray,
22 rhs: &dyn Array,
23 op: NumericOperator,
24 ) -> VortexResult<Option<ArrayRef>> {
25 if lhs.values().len() > lhs.codes().len() {
27 return Ok(None);
28 }
29
30 if !lhs.has_all_values_referenced() {
34 return Ok(None);
35 }
36
37 if let Some(rhs_scalar) = rhs.as_constant() {
39 let values_result = numeric(
40 lhs.values(),
41 ConstantArray::new(rhs_scalar, lhs.values().len()).as_ref(),
42 op,
43 )?;
44
45 let result = unsafe {
48 DictArray::new_unchecked(lhs.codes().clone(), values_result)
49 .set_all_values_referenced(lhs.has_all_values_referenced())
50 .into_array()
51 };
52
53 return Ok(Some(result));
54 }
55
56 Ok(None)
59 }
60}
61
62register_kernel!(NumericKernelAdapter(DictVTable).lift());
63
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_scalar::NumericOperator;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::compute::numeric;

    #[test]
    fn test_add_const() {
        // SAFETY: every code (0, 1, 2) is < 3 == values.len(), and each value index appears
        // among the codes, so the "all values referenced" claim holds.
        let dict = unsafe {
            DictArray::new_unchecked(
                buffer![0u32, 1, 2, 0, 1].into_array(),
                buffer![10i32, 20, 30].into_array(),
            )
            .set_all_values_referenced(true)
        };

        let res = numeric(
            dict.as_ref(),
            ConstantArray::new(5i32, 5).as_ref(),
            NumericOperator::Add,
        )
        .unwrap();

        let expected = PrimitiveArray::from_iter([15i32, 25, 35, 15, 25]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_mul_const() {
        // SAFETY: every code (0, 1, 2) is < 3 == values.len(), and each value index appears
        // among the codes, so the "all values referenced" claim holds.
        let dict = unsafe {
            DictArray::new_unchecked(
                buffer![0u32, 1, 2, 1, 0].into_array(),
                buffer![2i32, 3, 5].into_array(),
            )
            .set_all_values_referenced(true)
        };

        let res = numeric(
            dict.as_ref(),
            ConstantArray::new(10i32, 5).as_ref(),
            NumericOperator::Mul,
        )
        .unwrap();

        let expected = PrimitiveArray::from_iter([20i32, 30, 50, 30, 20]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_no_pushdown_when_not_all_values_referenced() {
        // Value index 2 (30) is never referenced, so the dict kernel declines. Only the
        // correctness of the final result is checked here.
        let dict = DictArray::try_new(
            buffer![0u32, 1, 0, 1].into_array(),
            buffer![10i32, 20, 30].into_array(),
        )
        .unwrap();

        let res = numeric(
            dict.as_ref(),
            ConstantArray::new(5i32, 4).as_ref(),
            NumericOperator::Add,
        )
        .unwrap();

        let expected = PrimitiveArray::from_iter([15i32, 25, 15, 25]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_values_outnumber_codes() {
        // More values (4) than codes (2): the dict kernel declines before inspecting `rhs`.
        let dict = DictArray::try_new(
            buffer![3u32, 0].into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )
        .unwrap();

        let res = numeric(
            dict.as_ref(),
            ConstantArray::new(1i32, 2).as_ref(),
            NumericOperator::Add,
        )
        .unwrap();

        let expected = PrimitiveArray::from_iter([41i32, 11]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_non_constant_rhs() {
        // SAFETY: every code (0, 1, 2) is < 3 == values.len(), and each value index appears
        // among the codes, so the "all values referenced" claim holds.
        let dict = unsafe {
            DictArray::new_unchecked(
                buffer![0u32, 1, 2, 0, 1].into_array(),
                buffer![10i32, 20, 30].into_array(),
            )
            .set_all_values_referenced(true)
        };
        // A non-constant rhs makes the dict kernel decline; the result must still be correct.
        let rhs = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);

        let res = numeric(dict.as_ref(), rhs.as_ref(), NumericOperator::Add).unwrap();

        let expected = PrimitiveArray::from_iter([11i32, 22, 33, 14, 25]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }

    #[test]
    fn test_sub_const() {
        // SAFETY: every code (0, 1, 2) is < 3 == values.len(), and each value index appears
        // among the codes, so the "all values referenced" claim holds.
        let dict = unsafe {
            DictArray::new_unchecked(
                buffer![0u32, 1, 2].into_array(),
                buffer![100i32, 50, 25].into_array(),
            )
            .set_all_values_referenced(true)
        };

        let res = numeric(
            dict.as_ref(),
            ConstantArray::new(10i32, 3).as_ref(),
            NumericOperator::Sub,
        )
        .unwrap();

        let expected = PrimitiveArray::from_iter([90i32, 40, 15]);
        assert_arrays_eq!(res.to_canonical().into_array(), expected.to_array());
    }
}