1use vortex_dtype::{DType, PType};
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_scalar::{BinaryNumericOperator, Scalar};
4
5use crate::arrays::ConstantArray;
6use crate::arrow::{Datum, from_arrow_array_with_len};
7use crate::encoding::Encoding;
8use crate::{Array, ArrayRef};
9
10pub trait BinaryNumericFn<A> {
11 fn binary_numeric(
12 &self,
13 array: A,
14 other: &dyn Array,
15 op: BinaryNumericOperator,
16 ) -> VortexResult<Option<ArrayRef>>;
17}
18
19impl<E: Encoding> BinaryNumericFn<&dyn Array> for E
20where
21 E: for<'a> BinaryNumericFn<&'a E::Array>,
22{
23 fn binary_numeric(
24 &self,
25 lhs: &dyn Array,
26 rhs: &dyn Array,
27 op: BinaryNumericOperator,
28 ) -> VortexResult<Option<ArrayRef>> {
29 let array_ref = lhs
30 .as_any()
31 .downcast_ref::<E::Array>()
32 .vortex_expect("Failed to downcast array");
33 BinaryNumericFn::binary_numeric(self, array_ref, rhs, op)
34 }
35}
36
37pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
39 binary_numeric(lhs, rhs, BinaryNumericOperator::Add)
40}
41
42pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
44 binary_numeric(
45 lhs,
46 &ConstantArray::new(rhs, lhs.len()).into_array(),
47 BinaryNumericOperator::Add,
48 )
49}
50
51pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
53 binary_numeric(lhs, rhs, BinaryNumericOperator::Sub)
54}
55
56pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
58 binary_numeric(
59 lhs,
60 &ConstantArray::new(rhs, lhs.len()).into_array(),
61 BinaryNumericOperator::Sub,
62 )
63}
64
65pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
67 binary_numeric(lhs, rhs, BinaryNumericOperator::Mul)
68}
69
70pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
72 binary_numeric(
73 lhs,
74 &ConstantArray::new(rhs, lhs.len()).into_array(),
75 BinaryNumericOperator::Mul,
76 )
77}
78
79pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
81 binary_numeric(lhs, rhs, BinaryNumericOperator::Div)
82}
83
84pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
86 binary_numeric(
87 lhs,
88 &ConstantArray::new(rhs, lhs.len()).into_array(),
89 BinaryNumericOperator::Mul,
90 )
91}
92
93pub fn binary_numeric(
94 lhs: &dyn Array,
95 rhs: &dyn Array,
96 op: BinaryNumericOperator,
97) -> VortexResult<ArrayRef> {
98 if lhs.len() != rhs.len() {
99 vortex_bail!(
100 "Numeric operations aren't supported on arrays of different lengths {} {}",
101 lhs.len(),
102 rhs.len()
103 )
104 }
105 if !matches!(lhs.dtype(), DType::Primitive(_, _))
106 || !matches!(rhs.dtype(), DType::Primitive(_, _))
107 || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
108 {
109 vortex_bail!(
110 "Numeric operations are only supported on two arrays sharing the same primitive-type: {} {}",
111 lhs.dtype(),
112 rhs.dtype()
113 )
114 }
115
116 if let Some(fun) = lhs.vtable().binary_numeric_fn() {
118 if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
119 return Ok(check_numeric_result(result, lhs, rhs));
120 }
121 }
122
123 if let Some(fun) = rhs.vtable().binary_numeric_fn() {
125 if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
126 return Ok(check_numeric_result(result, lhs, rhs));
127 }
128 }
129
130 log::debug!(
131 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
132 lhs.encoding(),
133 rhs.encoding(),
134 op,
135 );
136
137 arrow_numeric(lhs, rhs, op)
139}
140
141fn arrow_numeric(
146 lhs: &dyn Array,
147 rhs: &dyn Array,
148 operator: BinaryNumericOperator,
149) -> VortexResult<ArrayRef> {
150 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
151 let len = lhs.len();
152
153 let left = Datum::try_new(lhs.to_array())?;
154 let right = Datum::try_new(rhs.to_array())?;
155
156 let array = match operator {
157 BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
158 BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
159 BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
160 BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
161 BinaryNumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
162 BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
163 };
164
165 Ok(check_numeric_result(
166 from_arrow_array_with_len(array, len, nullable)?,
167 lhs,
168 rhs,
169 ))
170}
171
172#[inline(always)]
173fn check_numeric_result(result: ArrayRef, lhs: &dyn Array, rhs: &dyn Array) -> ArrayRef {
174 debug_assert_eq!(
175 result.len(),
176 lhs.len(),
177 "Numeric operation length mismatch {}",
178 rhs.encoding()
179 );
180 debug_assert_eq!(
181 result.dtype(),
182 &DType::Primitive(
183 PType::try_from(lhs.dtype())
184 .vortex_expect("Numeric operation DType failed to convert to PType"),
185 (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
186 ),
187 "Numeric operation dtype mismatch {}",
188 rhs.encoding()
189 );
190 result
191}
192
/// Reusable conformance checks for encodings that implement binary numeric kernels.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use num_traits::Num;
    use vortex_dtype::NativePType;
    use vortex_error::{VortexResult, vortex_err};
    use vortex_scalar::{BinaryNumericOperator, PrimitiveScalar, Scalar};

    use crate::arrays::ConstantArray;
    use crate::compute::{binary_numeric, scalar_at};
    use crate::{Array, ArrayRef};

    /// Reads every element of `array` as a [`Scalar`], panicking on the first failure.
    #[allow(clippy::unwrap_used)] // test-only helper: any failure should abort the test
    fn to_vec_of_scalar(array: &dyn Array) -> Vec<Scalar> {
        let scalars: VortexResult<Vec<Scalar>> =
            (0..array.len()).map(|idx| scalar_at(array, idx)).collect();
        scalars.unwrap()
    }

    /// Checks `array <op> 1` and `1 <op> array` for every [`BinaryNumericOperator`].
    ///
    /// Each result is compared against scalar arithmetic on the canonicalized values. A
    /// mismatch panics. Element-wise checked arithmetic that yields `None` also panics, so
    /// presumably callers must avoid inputs such as zero with `RDiv` (unverified here).
    #[allow(clippy::unwrap_used)] // test-only helper: any failure should abort the test
    pub fn test_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
    where
        Scalar: From<T>,
    {
        let original_values = to_vec_of_scalar(
            &array
                .to_canonical()
                .unwrap()
                .into_primitive()
                .unwrap()
                .into_array(),
        );

        let one = T::from(1)
            .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
            .unwrap();
        let scalar_one = Scalar::from(one).cast(array.dtype()).unwrap();
        let constant_one = ConstantArray::new(scalar_one.clone(), array.len()).into_array();

        let operators: [BinaryNumericOperator; 6] = [
            BinaryNumericOperator::Add,
            BinaryNumericOperator::Sub,
            BinaryNumericOperator::RSub,
            BinaryNumericOperator::Mul,
            BinaryNumericOperator::Div,
            BinaryNumericOperator::RDiv,
        ];

        for operator in operators {
            // array <op> constant
            let actual = to_vec_of_scalar(
                &binary_numeric(&array, &constant_one, operator).unwrap(),
            );
            let expected: Vec<Scalar> = original_values
                .iter()
                .map(|value| {
                    <Scalar as From<PrimitiveScalar<'_>>>::from(
                        value
                            .as_primitive()
                            .checked_binary_numeric(&scalar_one.as_primitive(), operator)
                            .unwrap(),
                    )
                })
                .collect();
            assert_eq!(
                actual,
                expected,
                "({:?}) {} (Constant array of {}) did not produce expected results",
                array,
                operator,
                scalar_one,
            );

            // constant <op> array
            let actual = to_vec_of_scalar(
                &binary_numeric(&constant_one, &array, operator).unwrap(),
            );
            let expected: Vec<Scalar> = original_values
                .iter()
                .map(|value| {
                    <Scalar as From<PrimitiveScalar<'_>>>::from(
                        scalar_one
                            .as_primitive()
                            .checked_binary_numeric(&value.as_primitive(), operator)
                            .unwrap(),
                    )
                })
                .collect();
            assert_eq!(
                actual,
                expected,
                "(Constant array of {}) {} ({:?}) did not produce expected results",
                scalar_one,
                operator,
                array,
            );
        }
    }
}
284
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::{scalar_at, sub_scalar};

    #[test]
    fn test_scalar_subtract_unsigned() {
        let values = buffer![1u16, 2, 3].into_array();
        let primitive = sub_scalar(&values, 1u16.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let differences = primitive.as_slice::<u16>().to_vec();
        assert_eq!(differences, &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let values = buffer![1i64, 2, 3].into_array();
        let primitive = sub_scalar(&values, (-1i64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let differences = primitive.as_slice::<i64>().to_vec();
        assert_eq!(differences, &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let result = sub_scalar(&values, Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        // Nulls in the input must stay null in the output.
        let mut actual = Vec::with_capacity(result.len());
        for index in 0..result.len() {
            actual.push(scalar_at(&result, index).unwrap());
        }

        let expected: Vec<Scalar> = [Some(0u16), Some(1), None, Some(2)]
            .into_iter()
            .map(Scalar::from)
            .collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let values = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let primitive = sub_scalar(&values, subtrahend.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let differences = primitive.as_slice::<f64>().to_vec();
        assert_eq!(differences, &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        // Float arithmetic saturates to infinities instead of erroring.
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let _results = sub_scalar(&values, 1.0f32.into()).unwrap();
        let _results = sub_scalar(&values, f32::MAX.into()).unwrap();
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let values = buffer![1u64, 2, 3].into_array();
        let _results =
            sub_scalar(&values, 1.5f64.into()).expect_err("Expected type mismatch error");
    }
}