vortex_array/compute/
binary_numeric.rs

1use vortex_dtype::{DType, PType};
2use vortex_error::{VortexExpect, VortexResult, vortex_bail};
3use vortex_scalar::{BinaryNumericOperator, Scalar};
4
5use crate::arrays::ConstantArray;
6use crate::arrow::{Datum, from_arrow_array_with_len};
7use crate::encoding::Encoding;
8use crate::{Array, ArrayRef};
9
10pub trait BinaryNumericFn<A> {
11    fn binary_numeric(
12        &self,
13        array: A,
14        other: &dyn Array,
15        op: BinaryNumericOperator,
16    ) -> VortexResult<Option<ArrayRef>>;
17}
18
19impl<E: Encoding> BinaryNumericFn<&dyn Array> for E
20where
21    E: for<'a> BinaryNumericFn<&'a E::Array>,
22{
23    fn binary_numeric(
24        &self,
25        lhs: &dyn Array,
26        rhs: &dyn Array,
27        op: BinaryNumericOperator,
28    ) -> VortexResult<Option<ArrayRef>> {
29        let array_ref = lhs
30            .as_any()
31            .downcast_ref::<E::Array>()
32            .vortex_expect("Failed to downcast array");
33        BinaryNumericFn::binary_numeric(self, array_ref, rhs, op)
34    }
35}
36
37/// Point-wise add two numeric arrays.
38pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
39    binary_numeric(lhs, rhs, BinaryNumericOperator::Add)
40}
41
42/// Point-wise add a scalar value to this array on the right-hand-side.
43pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
44    binary_numeric(
45        lhs,
46        &ConstantArray::new(rhs, lhs.len()).into_array(),
47        BinaryNumericOperator::Add,
48    )
49}
50
51/// Point-wise subtract two numeric arrays.
52pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
53    binary_numeric(lhs, rhs, BinaryNumericOperator::Sub)
54}
55
56/// Point-wise subtract a scalar value from this array on the right-hand-side.
57pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
58    binary_numeric(
59        lhs,
60        &ConstantArray::new(rhs, lhs.len()).into_array(),
61        BinaryNumericOperator::Sub,
62    )
63}
64
65/// Point-wise multiply two numeric arrays.
66pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
67    binary_numeric(lhs, rhs, BinaryNumericOperator::Mul)
68}
69
70/// Point-wise multiply a scalar value into this array on the right-hand-side.
71pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
72    binary_numeric(
73        lhs,
74        &ConstantArray::new(rhs, lhs.len()).into_array(),
75        BinaryNumericOperator::Mul,
76    )
77}
78
79/// Point-wise divide two numeric arrays.
80pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
81    binary_numeric(lhs, rhs, BinaryNumericOperator::Div)
82}
83
84/// Point-wise divide a scalar value into this array on the right-hand-side.
85pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
86    binary_numeric(
87        lhs,
88        &ConstantArray::new(rhs, lhs.len()).into_array(),
89        BinaryNumericOperator::Mul,
90    )
91}
92
93pub fn binary_numeric(
94    lhs: &dyn Array,
95    rhs: &dyn Array,
96    op: BinaryNumericOperator,
97) -> VortexResult<ArrayRef> {
98    if lhs.len() != rhs.len() {
99        vortex_bail!(
100            "Numeric operations aren't supported on arrays of different lengths {} {}",
101            lhs.len(),
102            rhs.len()
103        )
104    }
105    if !matches!(lhs.dtype(), DType::Primitive(_, _))
106        || !matches!(rhs.dtype(), DType::Primitive(_, _))
107        || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
108    {
109        vortex_bail!(
110            "Numeric operations are only supported on two arrays sharing the same primitive-type: {} {}",
111            lhs.dtype(),
112            rhs.dtype()
113        )
114    }
115
116    // Check if LHS supports the operation directly.
117    if let Some(fun) = lhs.vtable().binary_numeric_fn() {
118        if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
119            return Ok(check_numeric_result(result, lhs, rhs));
120        }
121    }
122
123    // Check if RHS supports the operation directly.
124    if let Some(fun) = rhs.vtable().binary_numeric_fn() {
125        if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
126            return Ok(check_numeric_result(result, lhs, rhs));
127        }
128    }
129
130    log::debug!(
131        "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
132        lhs.encoding(),
133        rhs.encoding(),
134        op,
135    );
136
137    // If neither side implements the trait, then we delegate to Arrow compute.
138    arrow_numeric(lhs, rhs, op)
139}
140
141/// Implementation of `BinaryBooleanFn` using the Arrow crate.
142///
143/// Note that other encodings should handle a constant RHS value, so we can assume here that
144/// the RHS is not constant and expand to a full array.
145fn arrow_numeric(
146    lhs: &dyn Array,
147    rhs: &dyn Array,
148    operator: BinaryNumericOperator,
149) -> VortexResult<ArrayRef> {
150    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
151    let len = lhs.len();
152
153    let left = Datum::try_new(lhs)?;
154    let right = Datum::try_new(rhs)?;
155
156    let array = match operator {
157        BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
158        BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
159        BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
160        BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
161        BinaryNumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
162        BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
163    };
164
165    Ok(check_numeric_result(
166        from_arrow_array_with_len(array, len, nullable)?,
167        lhs,
168        rhs,
169    ))
170}
171
172#[inline(always)]
173fn check_numeric_result(result: ArrayRef, lhs: &dyn Array, rhs: &dyn Array) -> ArrayRef {
174    debug_assert_eq!(
175        result.len(),
176        lhs.len(),
177        "Numeric operation length mismatch {}",
178        rhs.encoding()
179    );
180    debug_assert_eq!(
181        result.dtype(),
182        &DType::Primitive(
183            PType::try_from(lhs.dtype())
184                .vortex_expect("Numeric operation DType failed to convert to PType"),
185            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
186        ),
187        "Numeric operation dtype mismatch {}",
188        rhs.encoding()
189    );
190    result
191}
192
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::{scalar_at, sub_scalar};

    /// Subtracting an unsigned scalar from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<u16>(), &[0u16, 1, 2]);
    }

    /// Subtracting a negative scalar increases every element.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<i64>(), &[2i64, 3, 4]);
    }

    /// Null elements stay null; valid elements are decremented.
    #[test]
    fn test_scalar_subtract_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let result = sub_scalar(&values, Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let mut actual = Vec::with_capacity(result.len());
        for index in 0..result.len() {
            actual.push(scalar_at(&result, index).unwrap());
        }

        let expected: Vec<Scalar> = [Some(0u16), Some(1), None, Some(2)]
            .into_iter()
            .map(Scalar::from)
            .collect();
        assert_eq!(actual, expected);
    }

    /// Subtracting a negative float scalar increases every element.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<f64>(), &[2.0f64, 3.0, 4.0]);
    }

    /// Float subtraction that leaves the finite range must not produce an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    /// Subtracting a scalar of a different primitive type must be rejected.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let outcome = sub_scalar(&input, 1.5f64.into());
        let _error = outcome.expect_err("Expected type mismatch error");
    }
}