// vortex_array/compute/numeric.rs

1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::DType;
6use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::{NumericOperator, Scalar};
8
9use crate::arrays::ConstantArray;
10use crate::arrow::{Datum, from_arrow_array_with_len};
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
12use crate::vtable::VTable;
13use crate::{Array, ArrayRef, IntoArray};
14
15/// Point-wise add two numeric arrays.
16///
17/// Errs at runtime if the sum would overflow or underflow.
18///
19/// The result is null at any index that either input is null.
20pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
21    numeric(lhs, rhs, NumericOperator::Add)
22}
23
24/// Point-wise add a scalar value to this array on the right-hand-side.
25pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
26    numeric(
27        lhs,
28        &ConstantArray::new(rhs, lhs.len()).into_array(),
29        NumericOperator::Add,
30    )
31}
32
33/// Point-wise subtract two numeric arrays.
34pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
35    numeric(lhs, rhs, NumericOperator::Sub)
36}
37
38/// Point-wise subtract a scalar value from this array on the right-hand-side.
39pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
40    numeric(
41        lhs,
42        &ConstantArray::new(rhs, lhs.len()).into_array(),
43        NumericOperator::Sub,
44    )
45}
46
47/// Point-wise multiply two numeric arrays.
48pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
49    numeric(lhs, rhs, NumericOperator::Mul)
50}
51
52/// Point-wise multiply a scalar value into this array on the right-hand-side.
53pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
54    numeric(
55        lhs,
56        &ConstantArray::new(rhs, lhs.len()).into_array(),
57        NumericOperator::Mul,
58    )
59}
60
61/// Point-wise divide two numeric arrays.
62pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
63    numeric(lhs, rhs, NumericOperator::Div)
64}
65
66/// Point-wise divide a scalar value into this array on the right-hand-side.
67pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
68    numeric(
69        lhs,
70        &ConstantArray::new(rhs, lhs.len()).into_array(),
71        NumericOperator::Mul,
72    )
73}
74
75/// Point-wise numeric operation between two arrays of the same type and length.
76pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
77    NUMERIC_FN
78        .invoke(&InvocationArgs {
79            inputs: &[lhs.into(), rhs.into()],
80            options: &op,
81        })?
82        .unwrap_array()
83}
84
85pub struct NumericKernelRef(ArcRef<dyn Kernel>);
86inventory::collect!(NumericKernelRef);
87
88pub trait NumericKernel: VTable {
89    fn numeric(
90        &self,
91        array: &Self::Array,
92        other: &dyn Array,
93        op: NumericOperator,
94    ) -> VortexResult<Option<ArrayRef>>;
95}
96
97#[derive(Debug)]
98pub struct NumericKernelAdapter<V: VTable>(pub V);
99
100impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
101    pub const fn lift(&'static self) -> NumericKernelRef {
102        NumericKernelRef(ArcRef::new_ref(self))
103    }
104}
105
106impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
107    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
108        let inputs = NumericArgs::try_from(args)?;
109        let Some(lhs) = inputs.lhs.as_opt::<V>() else {
110            return Ok(None);
111        };
112        Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
113    }
114}
115
116pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
117    let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
118    for kernel in inventory::iter::<NumericKernelRef> {
119        compute.register_kernel(kernel.0.clone());
120    }
121    compute
122});
123
/// The [`ComputeFnVTable`] backing [`NUMERIC_FN`].
struct Numeric;
125
126impl ComputeFnVTable for Numeric {
127    fn invoke(
128        &self,
129        args: &InvocationArgs,
130        kernels: &[ArcRef<dyn Kernel>],
131    ) -> VortexResult<Output> {
132        let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
133
134        // Check if LHS supports the operation directly.
135        for kernel in kernels {
136            if let Some(output) = kernel.invoke(args)? {
137                return Ok(output);
138            }
139        }
140        if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
141            return Ok(output);
142        }
143
144        // Check if RHS supports the operation directly.
145        let inverted_args = InvocationArgs {
146            inputs: &[rhs.into(), lhs.into()],
147            options: &operator.swap(),
148        };
149        for kernel in kernels {
150            if let Some(output) = kernel.invoke(&inverted_args)? {
151                return Ok(output);
152            }
153        }
154        if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
155            return Ok(output);
156        }
157
158        log::debug!(
159            "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
160            lhs.encoding_id(),
161            rhs.encoding_id(),
162            operator,
163        );
164
165        // If neither side implements the trait, then we delegate to Arrow compute.
166        Ok(arrow_numeric(lhs, rhs, operator)?.into())
167    }
168
169    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
170        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
171        if !matches!(
172            (lhs.dtype(), rhs.dtype()),
173            (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
174        ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
175        {
176            vortex_bail!(
177                "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
178                lhs.dtype(),
179                rhs.dtype()
180            )
181        }
182        Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
183    }
184
185    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
186        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
187        if lhs.len() != rhs.len() {
188            vortex_bail!(
189                "Numeric operations aren't supported on arrays of different lengths {} {}",
190                lhs.len(),
191                rhs.len()
192            )
193        }
194        Ok(lhs.len())
195    }
196
197    fn is_elementwise(&self) -> bool {
198        true
199    }
200}
201
202struct NumericArgs<'a> {
203    lhs: &'a dyn Array,
204    rhs: &'a dyn Array,
205    operator: NumericOperator,
206}
207
208impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
209    type Error = VortexError;
210
211    fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
212        if args.inputs.len() != 2 {
213            vortex_bail!("Numeric operations require exactly 2 inputs");
214        }
215        let lhs = args.inputs[0]
216            .array()
217            .ok_or_else(|| vortex_err!("LHS is not an array"))?;
218        let rhs = args.inputs[1]
219            .array()
220            .ok_or_else(|| vortex_err!("RHS is not an array"))?;
221        let operator = *args
222            .options
223            .as_any()
224            .downcast_ref::<NumericOperator>()
225            .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
226        Ok(Self { lhs, rhs, operator })
227    }
228}
229
230impl Options for NumericOperator {
231    fn as_any(&self) -> &dyn Any {
232        self
233    }
234}
235
236/// Implementation of `BinaryNumericFn` using the Arrow crate.
237///
238/// Note that other encodings should handle a constant RHS value, so we can assume here that
239/// the RHS is not constant and expand to a full array.
240fn arrow_numeric(
241    lhs: &dyn Array,
242    rhs: &dyn Array,
243    operator: NumericOperator,
244) -> VortexResult<ArrayRef> {
245    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
246    let len = lhs.len();
247
248    let left = Datum::try_new(lhs)?;
249    let right = Datum::try_new(rhs)?;
250
251    let array = match operator {
252        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
253        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
254        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
255        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
256        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
257        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
258    };
259
260    from_arrow_array_with_len(array.as_ref(), len, nullable)
261}
262
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    #[test]
    fn test_scalar_subtract_unsigned() {
        let values = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&values, 1u16.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<u16>(), &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let values = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&values, (-1i64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<i64>(), &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let result = sub_scalar(values.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        let mut actual = Vec::with_capacity(result.len());
        for index in 0..result.len() {
            actual.push(result.scalar_at(index).unwrap());
        }
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let values = buffer![1.0f64, 2.0, 3.0].into_array();
        let to_subtract = -1f64;
        let difference = sub_scalar(&values, to_subtract.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<f64>(), &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        // Float subtraction past the representable range must not error.
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&values, subtrahend.into()).unwrap();
        }
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let values = buffer![1u64, 2, 3].into_array();
        // Subtracting a scalar of an incompatible dtype must be rejected.
        let outcome = sub_scalar(&values, 1.5f64.into());
        let _err = outcome.expect_err("Expected type mismatch error");
    }
}