vortex_array/compute/
numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::{NumericOperator, Scalar};
11
12use crate::arrays::ConstantArray;
13use crate::arrow::{Datum, from_arrow_array_with_len};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::vtable::VTable;
16use crate::{Array, ArrayRef, IntoArray};
17
18/// Point-wise add two numeric arrays.
19///
20/// Errs at runtime if the sum would overflow or underflow.
21///
22/// The result is null at any index that either input is null.
23pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
24    numeric(lhs, rhs, NumericOperator::Add)
25}
26
27/// Point-wise add a scalar value to this array on the right-hand-side.
28pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
29    numeric(
30        lhs,
31        &ConstantArray::new(rhs, lhs.len()).into_array(),
32        NumericOperator::Add,
33    )
34}
35
36/// Point-wise subtract two numeric arrays.
37pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
38    numeric(lhs, rhs, NumericOperator::Sub)
39}
40
41/// Point-wise subtract a scalar value from this array on the right-hand-side.
42pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
43    numeric(
44        lhs,
45        &ConstantArray::new(rhs, lhs.len()).into_array(),
46        NumericOperator::Sub,
47    )
48}
49
50/// Point-wise multiply two numeric arrays.
51pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
52    numeric(lhs, rhs, NumericOperator::Mul)
53}
54
55/// Point-wise multiply a scalar value into this array on the right-hand-side.
56pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
57    numeric(
58        lhs,
59        &ConstantArray::new(rhs, lhs.len()).into_array(),
60        NumericOperator::Mul,
61    )
62}
63
64/// Point-wise divide two numeric arrays.
65pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
66    numeric(lhs, rhs, NumericOperator::Div)
67}
68
69/// Point-wise divide a scalar value into this array on the right-hand-side.
70pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
71    numeric(
72        lhs,
73        &ConstantArray::new(rhs, lhs.len()).into_array(),
74        NumericOperator::Mul,
75    )
76}
77
78/// Point-wise numeric operation between two arrays of the same type and length.
79pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
80    NUMERIC_FN
81        .invoke(&InvocationArgs {
82            inputs: &[lhs.into(), rhs.into()],
83            options: &op,
84        })?
85        .unwrap_array()
86}
87
88pub struct NumericKernelRef(ArcRef<dyn Kernel>);
89inventory::collect!(NumericKernelRef);
90
91pub trait NumericKernel: VTable {
92    fn numeric(
93        &self,
94        array: &Self::Array,
95        other: &dyn Array,
96        op: NumericOperator,
97    ) -> VortexResult<Option<ArrayRef>>;
98}
99
100#[derive(Debug)]
101pub struct NumericKernelAdapter<V: VTable>(pub V);
102
103impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
104    pub const fn lift(&'static self) -> NumericKernelRef {
105        NumericKernelRef(ArcRef::new_ref(self))
106    }
107}
108
109impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
110    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
111        let inputs = NumericArgs::try_from(args)?;
112        let Some(lhs) = inputs.lhs.as_opt::<V>() else {
113            return Ok(None);
114        };
115        Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
116    }
117}
118
119pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
120    let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
121    for kernel in inventory::iter::<NumericKernelRef> {
122        compute.register_kernel(kernel.0.clone());
123    }
124    compute
125});
126
127struct Numeric;
128
129impl ComputeFnVTable for Numeric {
130    fn invoke(
131        &self,
132        args: &InvocationArgs,
133        kernels: &[ArcRef<dyn Kernel>],
134    ) -> VortexResult<Output> {
135        let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
136
137        // Check if LHS supports the operation directly.
138        for kernel in kernels {
139            if let Some(output) = kernel.invoke(args)? {
140                return Ok(output);
141            }
142        }
143        if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
144            return Ok(output);
145        }
146
147        // Check if RHS supports the operation directly.
148        let inverted_args = InvocationArgs {
149            inputs: &[rhs.into(), lhs.into()],
150            options: &operator.swap(),
151        };
152        for kernel in kernels {
153            if let Some(output) = kernel.invoke(&inverted_args)? {
154                return Ok(output);
155            }
156        }
157        if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
158            return Ok(output);
159        }
160
161        log::debug!(
162            "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
163            lhs.encoding_id(),
164            rhs.encoding_id(),
165            operator,
166        );
167
168        // If neither side implements the trait, then we delegate to Arrow compute.
169        Ok(arrow_numeric(lhs, rhs, operator)?.into())
170    }
171
172    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
173        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
174        if !matches!(
175            (lhs.dtype(), rhs.dtype()),
176            (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
177        ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
178        {
179            vortex_bail!(
180                "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
181                lhs.dtype(),
182                rhs.dtype()
183            )
184        }
185        Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
186    }
187
188    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
189        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
190        if lhs.len() != rhs.len() {
191            vortex_bail!(
192                "Numeric operations aren't supported on arrays of different lengths {} {}",
193                lhs.len(),
194                rhs.len()
195            )
196        }
197        Ok(lhs.len())
198    }
199
200    fn is_elementwise(&self) -> bool {
201        true
202    }
203}
204
205struct NumericArgs<'a> {
206    lhs: &'a dyn Array,
207    rhs: &'a dyn Array,
208    operator: NumericOperator,
209}
210
211impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
212    type Error = VortexError;
213
214    fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
215        if args.inputs.len() != 2 {
216            vortex_bail!("Numeric operations require exactly 2 inputs");
217        }
218        let lhs = args.inputs[0]
219            .array()
220            .ok_or_else(|| vortex_err!("LHS is not an array"))?;
221        let rhs = args.inputs[1]
222            .array()
223            .ok_or_else(|| vortex_err!("RHS is not an array"))?;
224        let operator = *args
225            .options
226            .as_any()
227            .downcast_ref::<NumericOperator>()
228            .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
229        Ok(Self { lhs, rhs, operator })
230    }
231}
232
233impl Options for NumericOperator {
234    fn as_any(&self) -> &dyn Any {
235        self
236    }
237}
238
239/// Implementation of `BinaryNumericFn` using the Arrow crate.
240///
241/// Note that other encodings should handle a constant RHS value, so we can assume here that
242/// the RHS is not constant and expand to a full array.
243fn arrow_numeric(
244    lhs: &dyn Array,
245    rhs: &dyn Array,
246    operator: NumericOperator,
247) -> VortexResult<ArrayRef> {
248    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
249    let len = lhs.len();
250
251    let left = Datum::try_new(lhs)?;
252    let right = Datum::try_new(rhs)?;
253
254    let array = match operator {
255        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
256        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
257        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
258        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
259        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
260        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
261    };
262
263    from_arrow_array_with_len(array.as_ref(), len, nullable)
264}
265
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    /// Subtracting an unsigned scalar from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let actual = primitive.as_slice::<u16>().to_vec();
        assert_eq!(actual, &[0u16, 1, 2]);
    }

    /// Subtracting a negative scalar from a signed array increases every element.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let actual = primitive.as_slice::<i64>().to_vec();
        assert_eq!(actual, &[2i64, 3, 4]);
    }

    /// Null entries in the input stay null in the output.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        let actual: Vec<_> = (0..difference.len())
            .map(|position| difference.scalar_at(position).unwrap())
            .collect();
        assert_eq!(actual, expected);
    }

    /// Subtracting a negative float scalar from a float array.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let actual = primitive.as_slice::<f64>().to_vec();
        assert_eq!(actual, &[2.0f64, 3.0, 4.0]);
    }

    /// Float subtraction that goes out of range must not produce an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    /// Subtracting a scalar of an incompatible dtype must fail.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let mismatched = sub_scalar(&input, 1.5f64.into());
        let _error = mismatched.expect_err("Expected type mismatch error");
    }
}