vortex_array/compute/
numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::{NumericOperator, Scalar};
11
12use crate::arrays::ConstantArray;
13use crate::arrow::{Datum, from_arrow_array_with_len};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::vtable::VTable;
16use crate::{Array, ArrayRef, IntoArray};
17
18static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
19    let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
20    for kernel in inventory::iter::<NumericKernelRef> {
21        compute.register_kernel(kernel.0.clone());
22    }
23    compute
24});
25
26pub(crate) fn warm_up_vtable() -> usize {
27    NUMERIC_FN.kernels().len()
28}
29
30/// Point-wise add two numeric arrays.
31///
32/// Errs at runtime if the sum would overflow or underflow.
33///
34/// The result is null at any index that either input is null.
35pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
36    numeric(lhs, rhs, NumericOperator::Add)
37}
38
39/// Point-wise add a scalar value to this array on the right-hand-side.
40pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
41    numeric(
42        lhs,
43        &ConstantArray::new(rhs, lhs.len()).into_array(),
44        NumericOperator::Add,
45    )
46}
47
48/// Point-wise subtract two numeric arrays.
49pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
50    numeric(lhs, rhs, NumericOperator::Sub)
51}
52
53/// Point-wise subtract a scalar value from this array on the right-hand-side.
54pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
55    numeric(
56        lhs,
57        &ConstantArray::new(rhs, lhs.len()).into_array(),
58        NumericOperator::Sub,
59    )
60}
61
62/// Point-wise multiply two numeric arrays.
63pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
64    numeric(lhs, rhs, NumericOperator::Mul)
65}
66
67/// Point-wise multiply a scalar value into this array on the right-hand-side.
68pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
69    numeric(
70        lhs,
71        &ConstantArray::new(rhs, lhs.len()).into_array(),
72        NumericOperator::Mul,
73    )
74}
75
76/// Point-wise divide two numeric arrays.
77pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
78    numeric(lhs, rhs, NumericOperator::Div)
79}
80
81/// Point-wise divide a scalar value into this array on the right-hand-side.
82pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
83    numeric(
84        lhs,
85        &ConstantArray::new(rhs, lhs.len()).into_array(),
86        NumericOperator::Mul,
87    )
88}
89
90/// Point-wise numeric operation between two arrays of the same type and length.
91pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
92    NUMERIC_FN
93        .invoke(&InvocationArgs {
94            inputs: &[lhs.into(), rhs.into()],
95            options: &op,
96        })?
97        .unwrap_array()
98}
99
/// Handle to a numeric [`Kernel`], wrapping an `ArcRef<dyn Kernel>`.
///
/// Instances are gathered with `inventory::collect!`; `NUMERIC_FN` iterates over the collected
/// values and registers each wrapped kernel when it is first initialized.
100pub struct NumericKernelRef(ArcRef<dyn Kernel>);
101inventory::collect!(NumericKernelRef);
102
/// Encoding-specific implementation of a binary numeric operation.
///
/// Implementations are exposed to the compute function through [`NumericKernelAdapter`].
103pub trait NumericKernel: VTable {
    /// Applies `op` with `array` as the left-hand side and `other` as the right-hand side.
    ///
    /// Returning `Ok(None)` is forwarded unchanged by [`NumericKernelAdapter`]; the `Numeric`
    /// dispatcher then moves on to its next candidate implementation.
104    fn numeric(
105        &self,
106        array: &Self::Array,
107        other: &dyn Array,
108        op: NumericOperator,
109    ) -> VortexResult<Option<ArrayRef>>;
110}
111
/// Newtype around a vtable `V` that adapts its [`NumericKernel`] implementation to the generic
/// [`Kernel`] interface.
112#[derive(Debug)]
113pub struct NumericKernelAdapter<V: VTable>(pub V);
114
115impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
116    pub const fn lift(&'static self) -> NumericKernelRef {
117        NumericKernelRef(ArcRef::new_ref(self))
118    }
119}
120
121impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
122    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
123        let inputs = NumericArgs::try_from(args)?;
124        let Some(lhs) = inputs.lhs.as_opt::<V>() else {
125            return Ok(None);
126        };
127        Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
128    }
129}
130
131struct Numeric;
132
133impl ComputeFnVTable for Numeric {
134    fn invoke(
135        &self,
136        args: &InvocationArgs,
137        kernels: &[ArcRef<dyn Kernel>],
138    ) -> VortexResult<Output> {
139        let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
140
141        for kernel in kernels {
142            if let Some(output) = kernel.invoke(args)? {
143                return Ok(output);
144            }
145        }
146
147        // Check if LHS supports the operation directly.
148        if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
149            return Ok(output);
150        }
151
152        // Check if RHS supports the operation directly.
153        let inverted_args = InvocationArgs {
154            inputs: &[rhs.into(), lhs.into()],
155            options: &operator.swap(),
156        };
157        for kernel in kernels {
158            if let Some(output) = kernel.invoke(&inverted_args)? {
159                return Ok(output);
160            }
161        }
162        if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
163            return Ok(output);
164        }
165
166        log::debug!(
167            "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
168            lhs.encoding_id(),
169            rhs.encoding_id(),
170            operator,
171        );
172
173        // If neither side implements the trait, then we delegate to Arrow compute.
174        Ok(arrow_numeric(lhs, rhs, operator)?.into())
175    }
176
177    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
178        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
179        if !matches!(
180            (lhs.dtype(), rhs.dtype()),
181            (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
182        ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
183        {
184            vortex_bail!(
185                "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
186                lhs.dtype(),
187                rhs.dtype()
188            )
189        }
190        Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
191    }
192
193    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
194        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
195        if lhs.len() != rhs.len() {
196            vortex_bail!(
197                "Numeric operations aren't supported on arrays of different lengths {} {}",
198                lhs.len(),
199                rhs.len()
200            )
201        }
202        Ok(lhs.len())
203    }
204
205    fn is_elementwise(&self) -> bool {
206        true
207    }
208}
209
210struct NumericArgs<'a> {
211    lhs: &'a dyn Array,
212    rhs: &'a dyn Array,
213    operator: NumericOperator,
214}
215
216impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
217    type Error = VortexError;
218
219    fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
220        if args.inputs.len() != 2 {
221            vortex_bail!("Numeric operations require exactly 2 inputs");
222        }
223        let lhs = args.inputs[0]
224            .array()
225            .ok_or_else(|| vortex_err!("LHS is not an array"))?;
226        let rhs = args.inputs[1]
227            .array()
228            .ok_or_else(|| vortex_err!("RHS is not an array"))?;
229        let operator = *args
230            .options
231            .as_any()
232            .downcast_ref::<NumericOperator>()
233            .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
234        Ok(Self { lhs, rhs, operator })
235    }
236}
237
238impl Options for NumericOperator {
239    fn as_any(&self) -> &dyn Any {
240        self
241    }
242}
243
244/// Implementation of `BinaryNumericFn` using the Arrow crate.
245///
246/// Note that other encodings should handle a constant RHS value, so we can assume here that
247/// the RHS is not constant and expand to a full array.
248fn arrow_numeric(
249    lhs: &dyn Array,
250    rhs: &dyn Array,
251    operator: NumericOperator,
252) -> VortexResult<ArrayRef> {
253    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
254    let len = lhs.len();
255
256    let left = Datum::try_new(lhs)?;
257    let right = Datum::try_new(rhs)?;
258
259    let array = match operator {
260        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
261        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
262        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
263        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
264        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
265        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
266    };
267
268    Ok(from_arrow_array_with_len(array.as_ref(), len, nullable))
269}
270
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    /// Subtracting an unsigned scalar lowers each element by that amount.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap().to_primitive();
        let observed = difference.as_slice::<u16>().to_vec();
        assert_eq!(observed, &[0u16, 1, 2]);
    }

    /// Subtracting a negative signed scalar raises each element.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into())
            .unwrap()
            .to_primitive();
        let observed = difference.as_slice::<i64>().to_vec();
        assert_eq!(observed, &[2i64, 3, 4]);
    }

    /// Null entries stay null; all other entries are reduced by the scalar.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive();

        let mut observed = Vec::new();
        for position in 0..difference.len() {
            observed.push(difference.scalar_at(position));
        }

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        assert_eq!(observed, expected);
    }

    /// Subtracting a negative float scalar raises each element.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into())
            .unwrap()
            .to_primitive();
        let observed = difference.as_slice::<f64>().to_vec();
        assert_eq!(observed, &[2.0f64, 3.0, 4.0]);
    }

    /// Float subtraction below `f32::MIN` must not produce an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    /// Subtracting a scalar of an incompatible dtype should fail.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let outcome = sub_scalar(&input, 1.5f64.into());
        let _error = outcome.expect_err("Expected type mismatch error");
    }
}