// vortex_array/compute/numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::VortexError;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_scalar::NumericOperator;
14use vortex_scalar::Scalar;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::arrays::ConstantArray;
20use crate::arrow::Datum;
21use crate::arrow::from_arrow_array_with_len;
22use crate::compute::ComputeFn;
23use crate::compute::ComputeFnVTable;
24use crate::compute::InvocationArgs;
25use crate::compute::Kernel;
26use crate::compute::Options;
27use crate::compute::Output;
28use crate::vtable::VTable;
29
30static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
31    let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
32    for kernel in inventory::iter::<NumericKernelRef> {
33        compute.register_kernel(kernel.0.clone());
34    }
35    compute
36});
37
38pub(crate) fn warm_up_vtable() -> usize {
39    NUMERIC_FN.kernels().len()
40}
41
42/// Point-wise add two numeric arrays.
43///
44/// Errs at runtime if the sum would overflow or underflow.
45///
46/// The result is null at any index that either input is null.
47pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
48    numeric(lhs, rhs, NumericOperator::Add)
49}
50
51/// Point-wise add a scalar value to this array on the right-hand-side.
52pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
53    numeric(
54        lhs,
55        &ConstantArray::new(rhs, lhs.len()).into_array(),
56        NumericOperator::Add,
57    )
58}
59
60/// Point-wise subtract two numeric arrays.
61pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
62    numeric(lhs, rhs, NumericOperator::Sub)
63}
64
65/// Point-wise subtract a scalar value from this array on the right-hand-side.
66pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
67    numeric(
68        lhs,
69        &ConstantArray::new(rhs, lhs.len()).into_array(),
70        NumericOperator::Sub,
71    )
72}
73
74/// Point-wise multiply two numeric arrays.
75pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
76    numeric(lhs, rhs, NumericOperator::Mul)
77}
78
79/// Point-wise multiply a scalar value into this array on the right-hand-side.
80pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
81    numeric(
82        lhs,
83        &ConstantArray::new(rhs, lhs.len()).into_array(),
84        NumericOperator::Mul,
85    )
86}
87
88/// Point-wise divide two numeric arrays.
89pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
90    numeric(lhs, rhs, NumericOperator::Div)
91}
92
93/// Point-wise divide a scalar value into this array on the right-hand-side.
94pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
95    numeric(
96        lhs,
97        &ConstantArray::new(rhs, lhs.len()).into_array(),
98        NumericOperator::Mul,
99    )
100}
101
102/// Point-wise numeric operation between two arrays of the same type and length.
103pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
104    NUMERIC_FN
105        .invoke(&InvocationArgs {
106            inputs: &[lhs.into(), rhs.into()],
107            options: &op,
108        })?
109        .unwrap_array()
110}
111
112pub struct NumericKernelRef(ArcRef<dyn Kernel>);
113inventory::collect!(NumericKernelRef);
114
115pub trait NumericKernel: VTable {
116    fn numeric(
117        &self,
118        array: &Self::Array,
119        other: &dyn Array,
120        op: NumericOperator,
121    ) -> VortexResult<Option<ArrayRef>>;
122}
123
124#[derive(Debug)]
125pub struct NumericKernelAdapter<V: VTable>(pub V);
126
127impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
128    pub const fn lift(&'static self) -> NumericKernelRef {
129        NumericKernelRef(ArcRef::new_ref(self))
130    }
131}
132
133impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
134    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
135        let inputs = NumericArgs::try_from(args)?;
136        let Some(lhs) = inputs.lhs.as_opt::<V>() else {
137            return Ok(None);
138        };
139        Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
140    }
141}
142
/// The [`ComputeFnVTable`] behind the `numeric` compute function ([`NUMERIC_FN`]).
struct Numeric;
144
145impl ComputeFnVTable for Numeric {
146    fn invoke(
147        &self,
148        args: &InvocationArgs,
149        kernels: &[ArcRef<dyn Kernel>],
150    ) -> VortexResult<Output> {
151        let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
152
153        for kernel in kernels {
154            if let Some(output) = kernel.invoke(args)? {
155                return Ok(output);
156            }
157        }
158
159        // Check if LHS supports the operation directly.
160        if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
161            return Ok(output);
162        }
163
164        // Check if RHS supports the operation directly.
165        let inverted_args = InvocationArgs {
166            inputs: &[rhs.into(), lhs.into()],
167            options: &operator.swap(),
168        };
169        for kernel in kernels {
170            if let Some(output) = kernel.invoke(&inverted_args)? {
171                return Ok(output);
172            }
173        }
174        if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
175            return Ok(output);
176        }
177
178        log::debug!(
179            "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
180            lhs.encoding_id(),
181            rhs.encoding_id(),
182            operator,
183        );
184
185        // If neither side implements the trait, then we delegate to Arrow compute.
186        Ok(arrow_numeric(lhs, rhs, operator)?.into())
187    }
188
189    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
190        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
191        if !matches!(
192            (lhs.dtype(), rhs.dtype()),
193            (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
194        ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
195        {
196            vortex_bail!(
197                "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
198                lhs.dtype(),
199                rhs.dtype()
200            )
201        }
202        Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
203    }
204
205    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
206        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
207        if lhs.len() != rhs.len() {
208            vortex_bail!(
209                "Numeric operations aren't supported on arrays of different lengths {} {}",
210                lhs.len(),
211                rhs.len()
212            )
213        }
214        Ok(lhs.len())
215    }
216
217    fn is_elementwise(&self) -> bool {
218        true
219    }
220}
221
222struct NumericArgs<'a> {
223    lhs: &'a dyn Array,
224    rhs: &'a dyn Array,
225    operator: NumericOperator,
226}
227
228impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
229    type Error = VortexError;
230
231    fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
232        if args.inputs.len() != 2 {
233            vortex_bail!("Numeric operations require exactly 2 inputs");
234        }
235        let lhs = args.inputs[0]
236            .array()
237            .ok_or_else(|| vortex_err!("LHS is not an array"))?;
238        let rhs = args.inputs[1]
239            .array()
240            .ok_or_else(|| vortex_err!("RHS is not an array"))?;
241        let operator = *args
242            .options
243            .as_any()
244            .downcast_ref::<NumericOperator>()
245            .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
246        Ok(Self { lhs, rhs, operator })
247    }
248}
249
250impl Options for NumericOperator {
251    fn as_any(&self) -> &dyn Any {
252        self
253    }
254}
255
256/// Implementation of `BinaryNumericFn` using the Arrow crate.
257///
258/// Note that other encodings should handle a constant RHS value, so we can assume here that
259/// the RHS is not constant and expand to a full array.
260fn arrow_numeric(
261    lhs: &dyn Array,
262    rhs: &dyn Array,
263    operator: NumericOperator,
264) -> VortexResult<ArrayRef> {
265    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
266    let len = lhs.len();
267
268    let left = Datum::try_new(lhs)?;
269    let right = Datum::try_new_with_target_datatype(rhs, left.data_type())?;
270
271    let array = match operator {
272        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
273        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
274        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
275        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
276        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
277        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
278    };
279
280    Ok(from_arrow_array_with_len(array.as_ref(), len, nullable))
281}
282
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<u16>(), &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<i64>(), &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive();

        // The null slot must stay null; every other slot is decremented by one.
        let expected = [Some(0u16), Some(1), None, Some(2)]
            .map(Scalar::from)
            .to_vec();
        let actual: Vec<Scalar> = (0..difference.len())
            .map(|idx| difference.scalar_at(idx))
            .collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let difference = sub_scalar(&input, (-1f64).into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<f64>(), &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        // Floating-point subtraction saturates to infinity instead of erroring.
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _ = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        // A u64 array minus an f64 scalar has mismatched dtypes and must be rejected.
        let outcome = sub_scalar(&input, 1.5f64.into());
        assert!(outcome.is_err(), "Expected type mismatch error");
    }
}
362}