1use vortex_dtype::{DType, PType};
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_scalar::{BinaryNumericOperator, Scalar};
4
5use crate::arrays::ConstantArray;
6use crate::arrow::{Datum, from_arrow_array_with_len};
7use crate::encoding::Encoding;
8use crate::{Array, ArrayRef};
9
10pub trait BinaryNumericFn<A> {
11 fn binary_numeric(
12 &self,
13 array: A,
14 other: &dyn Array,
15 op: BinaryNumericOperator,
16 ) -> VortexResult<Option<ArrayRef>>;
17}
18
19impl<E: Encoding> BinaryNumericFn<&dyn Array> for E
20where
21 E: for<'a> BinaryNumericFn<&'a E::Array>,
22{
23 fn binary_numeric(
24 &self,
25 lhs: &dyn Array,
26 rhs: &dyn Array,
27 op: BinaryNumericOperator,
28 ) -> VortexResult<Option<ArrayRef>> {
29 let array_ref = lhs
30 .as_any()
31 .downcast_ref::<E::Array>()
32 .vortex_expect("Failed to downcast array");
33 BinaryNumericFn::binary_numeric(self, array_ref, rhs, op)
34 }
35}
36
37pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
39 binary_numeric(lhs, rhs, BinaryNumericOperator::Add)
40}
41
42pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
44 binary_numeric(
45 lhs,
46 &ConstantArray::new(rhs, lhs.len()).into_array(),
47 BinaryNumericOperator::Add,
48 )
49}
50
51pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
53 binary_numeric(lhs, rhs, BinaryNumericOperator::Sub)
54}
55
56pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
58 binary_numeric(
59 lhs,
60 &ConstantArray::new(rhs, lhs.len()).into_array(),
61 BinaryNumericOperator::Sub,
62 )
63}
64
65pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
67 binary_numeric(lhs, rhs, BinaryNumericOperator::Mul)
68}
69
70pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
72 binary_numeric(
73 lhs,
74 &ConstantArray::new(rhs, lhs.len()).into_array(),
75 BinaryNumericOperator::Mul,
76 )
77}
78
79pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
81 binary_numeric(lhs, rhs, BinaryNumericOperator::Div)
82}
83
84pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
86 binary_numeric(
87 lhs,
88 &ConstantArray::new(rhs, lhs.len()).into_array(),
89 BinaryNumericOperator::Mul,
90 )
91}
92
93pub fn binary_numeric(
94 lhs: &dyn Array,
95 rhs: &dyn Array,
96 op: BinaryNumericOperator,
97) -> VortexResult<ArrayRef> {
98 if lhs.len() != rhs.len() {
99 vortex_bail!(
100 "Numeric operations aren't supported on arrays of different lengths {} {}",
101 lhs.len(),
102 rhs.len()
103 )
104 }
105 if !matches!(lhs.dtype(), DType::Primitive(_, _))
106 || !matches!(rhs.dtype(), DType::Primitive(_, _))
107 || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
108 {
109 vortex_bail!(
110 "Numeric operations are only supported on two arrays sharing the same primitive-type: {} {}",
111 lhs.dtype(),
112 rhs.dtype()
113 )
114 }
115
116 if let Some(fun) = lhs.vtable().binary_numeric_fn() {
118 if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
119 return Ok(check_numeric_result(result, lhs, rhs));
120 }
121 }
122
123 if let Some(fun) = rhs.vtable().binary_numeric_fn() {
125 if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
126 return Ok(check_numeric_result(result, lhs, rhs));
127 }
128 }
129
130 log::debug!(
131 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
132 lhs.encoding(),
133 rhs.encoding(),
134 op,
135 );
136
137 arrow_numeric(lhs, rhs, op)
139}
140
141fn arrow_numeric(
146 lhs: &dyn Array,
147 rhs: &dyn Array,
148 operator: BinaryNumericOperator,
149) -> VortexResult<ArrayRef> {
150 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
151 let len = lhs.len();
152
153 let left = Datum::try_new(lhs)?;
154 let right = Datum::try_new(rhs)?;
155
156 let array = match operator {
157 BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
158 BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
159 BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
160 BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
161 BinaryNumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
162 BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
163 };
164
165 Ok(check_numeric_result(
166 from_arrow_array_with_len(array, len, nullable)?,
167 lhs,
168 rhs,
169 ))
170}
171
172#[inline(always)]
173fn check_numeric_result(result: ArrayRef, lhs: &dyn Array, rhs: &dyn Array) -> ArrayRef {
174 debug_assert_eq!(
175 result.len(),
176 lhs.len(),
177 "Numeric operation length mismatch {}",
178 rhs.encoding()
179 );
180 debug_assert_eq!(
181 result.dtype(),
182 &DType::Primitive(
183 PType::try_from(lhs.dtype())
184 .vortex_expect("Numeric operation DType failed to convert to PType"),
185 (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
186 ),
187 "Numeric operation dtype mismatch {}",
188 rhs.encoding()
189 );
190 result
191}
192
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::{scalar_at, sub_scalar};

    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let output = primitive.as_slice::<u16>().to_vec();
        assert_eq!(output, &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let output = primitive.as_slice::<i64>().to_vec();
        assert_eq!(output, &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let output = sub_scalar(&input, Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        // Nulls must survive the subtraction in place.
        let mut observed = Vec::with_capacity(output.len());
        for idx in 0..output.len() {
            observed.push(scalar_at(&output, idx).unwrap());
        }

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        assert_eq!(observed, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let output = primitive.as_slice::<f64>().to_vec();
        assert_eq!(output, &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        // Floating-point overflow saturates to infinity rather than erroring.
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let outcome = sub_scalar(&input, 1.5f64.into());
        let _err = outcome.expect_err("Expected type mismatch error");
    }
}