// vortex_array/compute/numeric.rs

1use std::any::Any;
2use std::sync::LazyLock;
3
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::{NumericOperator, Scalar};
7
8use crate::arcref::ArcRef;
9use crate::arrays::ConstantArray;
10use crate::arrow::{Datum, from_arrow_array_with_len};
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
12use crate::encoding::Encoding;
13use crate::{Array, ArrayRef};
14
15/// Point-wise add two numeric arrays.
16pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
17    numeric(lhs, rhs, NumericOperator::Add)
18}
19
20/// Point-wise add a scalar value to this array on the right-hand-side.
21pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
22    numeric(
23        lhs,
24        &ConstantArray::new(rhs, lhs.len()).into_array(),
25        NumericOperator::Add,
26    )
27}
28
29/// Point-wise subtract two numeric arrays.
30pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
31    numeric(lhs, rhs, NumericOperator::Sub)
32}
33
34/// Point-wise subtract a scalar value from this array on the right-hand-side.
35pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
36    numeric(
37        lhs,
38        &ConstantArray::new(rhs, lhs.len()).into_array(),
39        NumericOperator::Sub,
40    )
41}
42
43/// Point-wise multiply two numeric arrays.
44pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
45    numeric(lhs, rhs, NumericOperator::Mul)
46}
47
48/// Point-wise multiply a scalar value into this array on the right-hand-side.
49pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
50    numeric(
51        lhs,
52        &ConstantArray::new(rhs, lhs.len()).into_array(),
53        NumericOperator::Mul,
54    )
55}
56
57/// Point-wise divide two numeric arrays.
58pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
59    numeric(lhs, rhs, NumericOperator::Div)
60}
61
62/// Point-wise divide a scalar value into this array on the right-hand-side.
63pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
64    numeric(
65        lhs,
66        &ConstantArray::new(rhs, lhs.len()).into_array(),
67        NumericOperator::Mul,
68    )
69}
70
71/// Point-wise numeric operation between two arrays of the same type and length.
72pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
73    NUMERIC_FN
74        .invoke(&InvocationArgs {
75            inputs: &[lhs.into(), rhs.into()],
76            options: &op,
77        })?
78        .unwrap_array()
79}
80
81pub struct NumericKernelRef(ArcRef<dyn Kernel>);
82inventory::collect!(NumericKernelRef);
83
84pub trait NumericKernel: Encoding {
85    fn numeric(
86        &self,
87        array: &Self::Array,
88        other: &dyn Array,
89        op: NumericOperator,
90    ) -> VortexResult<Option<ArrayRef>>;
91}
92
93#[derive(Debug)]
94pub struct NumericKernelAdapter<E: Encoding>(pub E);
95
96impl<E: Encoding + NumericKernel> NumericKernelAdapter<E> {
97    pub const fn lift(&'static self) -> NumericKernelRef {
98        NumericKernelRef(ArcRef::new_ref(self))
99    }
100}
101
102impl<E: Encoding + NumericKernel> Kernel for NumericKernelAdapter<E> {
103    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
104        let inputs = NumericArgs::try_from(args)?;
105        let Some(lhs) = inputs.lhs.as_any().downcast_ref::<E::Array>() else {
106            return Ok(None);
107        };
108        Ok(E::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
109    }
110}
111
112pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
113    let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
114    for kernel in inventory::iter::<NumericKernelRef> {
115        compute.register_kernel(kernel.0.clone());
116    }
117    compute
118});
119
/// The [`ComputeFnVTable`] backing [`NUMERIC_FN`].
struct Numeric;
121
122impl ComputeFnVTable for Numeric {
123    fn invoke(
124        &self,
125        args: &InvocationArgs,
126        kernels: &[ArcRef<dyn Kernel>],
127    ) -> VortexResult<Output> {
128        let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
129
130        // Check if LHS supports the operation directly.
131        for kernel in kernels {
132            if let Some(output) = kernel.invoke(args)? {
133                return Ok(output);
134            }
135        }
136        if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
137            return Ok(output);
138        }
139
140        // Check if RHS supports the operation directly.
141        let inverted_args = InvocationArgs {
142            inputs: &[rhs.into(), lhs.into()],
143            options: &operator.swap(),
144        };
145        for kernel in kernels {
146            if let Some(output) = kernel.invoke(&inverted_args)? {
147                return Ok(output);
148            }
149        }
150        if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
151            return Ok(output);
152        }
153
154        log::debug!(
155            "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
156            lhs.encoding(),
157            rhs.encoding(),
158            operator,
159        );
160
161        // If neither side implements the trait, then we delegate to Arrow compute.
162        Ok(arrow_numeric(lhs, rhs, operator)?.into())
163    }
164
165    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
166        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
167        if !matches!(
168            (lhs.dtype(), rhs.dtype()),
169            (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
170        ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
171        {
172            vortex_bail!(
173                "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
174                lhs.dtype(),
175                rhs.dtype()
176            )
177        }
178        Ok(lhs
179            .dtype()
180            .with_nullability((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()))
181    }
182
183    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
184        let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
185        if lhs.len() != rhs.len() {
186            vortex_bail!(
187                "Numeric operations aren't supported on arrays of different lengths {} {}",
188                lhs.len(),
189                rhs.len()
190            )
191        }
192        Ok(lhs.len())
193    }
194
195    fn is_elementwise(&self) -> bool {
196        true
197    }
198}
199
200struct NumericArgs<'a> {
201    lhs: &'a dyn Array,
202    rhs: &'a dyn Array,
203    operator: NumericOperator,
204}
205
206impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
207    type Error = VortexError;
208
209    fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
210        if args.inputs.len() != 2 {
211            vortex_bail!("Numeric operations require exactly 2 inputs");
212        }
213        let lhs = args.inputs[0]
214            .array()
215            .ok_or_else(|| vortex_err!("LHS is not an array"))?;
216        let rhs = args.inputs[1]
217            .array()
218            .ok_or_else(|| vortex_err!("RHS is not an array"))?;
219        let operator = *args
220            .options
221            .as_any()
222            .downcast_ref::<NumericOperator>()
223            .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
224        Ok(Self { lhs, rhs, operator })
225    }
226}
227
228impl Options for NumericOperator {
229    fn as_any(&self) -> &dyn Any {
230        self
231    }
232}
233
234/// Implementation of `BinaryNumericFn` using the Arrow crate.
235///
236/// Note that other encodings should handle a constant RHS value, so we can assume here that
237/// the RHS is not constant and expand to a full array.
238fn arrow_numeric(
239    lhs: &dyn Array,
240    rhs: &dyn Array,
241    operator: NumericOperator,
242) -> VortexResult<ArrayRef> {
243    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
244    let len = lhs.len();
245
246    let left = Datum::try_new(lhs)?;
247    let right = Datum::try_new(rhs)?;
248
249    let array = match operator {
250        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
251        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
252        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
253        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
254        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
255        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
256    };
257
258    from_arrow_array_with_len(array, len, nullable)
259}
260
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::{scalar_at, sub_scalar};

    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<u16>(), &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        // Subtracting a negative scalar increases every element.
        let difference = sub_scalar(&input, (-1i64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<i64>(), &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(&input, Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        // Null slots in the input must remain null in the output.
        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        let mut actual = Vec::with_capacity(difference.len());
        for index in 0..difference.len() {
            actual.push(scalar_at(&difference, index).unwrap());
        }
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let difference = sub_scalar(&input, (-1f64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(difference.as_slice::<f64>(), &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        // Float overflow past f32::MIN must not surface as an error.
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        // A u64 array and an f64 scalar do not share a numeric dtype.
        let _error =
            sub_scalar(&input, 1.5f64.into()).expect_err("Expected type mismatch error");
    }
}