Skip to main content

vortex_array/compute/
numeric.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use vortex_error::VortexResult;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::arrow::Datum;
13use crate::arrow::from_arrow_array_with_len;
14use crate::compute::Options;
15use crate::scalar::NumericOperator;
16use crate::scalar::Scalar;
17
18/// Point-wise add two numeric arrays.
19///
20/// Errs at runtime if the sum would overflow or underflow.
21///
22/// The result is null at any index that either input is null.
23pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
24    numeric(lhs, rhs, NumericOperator::Add)
25}
26
27/// Point-wise add a scalar value to this array on the right-hand-side.
28pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
29    numeric(
30        lhs,
31        &ConstantArray::new(rhs, lhs.len()).into_array(),
32        NumericOperator::Add,
33    )
34}
35
36/// Point-wise subtract two numeric arrays.
37pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
38    numeric(lhs, rhs, NumericOperator::Sub)
39}
40
41/// Point-wise subtract a scalar value from this array on the right-hand-side.
42pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
43    numeric(
44        lhs,
45        &ConstantArray::new(rhs, lhs.len()).into_array(),
46        NumericOperator::Sub,
47    )
48}
49
50/// Point-wise multiply two numeric arrays.
51pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
52    numeric(lhs, rhs, NumericOperator::Mul)
53}
54
55/// Point-wise multiply a scalar value into this array on the right-hand-side.
56pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
57    numeric(
58        lhs,
59        &ConstantArray::new(rhs, lhs.len()).into_array(),
60        NumericOperator::Mul,
61    )
62}
63
64/// Point-wise divide two numeric arrays.
65pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
66    numeric(lhs, rhs, NumericOperator::Div)
67}
68
69/// Point-wise divide a scalar value into this array on the right-hand-side.
70pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
71    numeric(
72        lhs,
73        &ConstantArray::new(rhs, lhs.len()).into_array(),
74        NumericOperator::Mul,
75    )
76}
77
78/// Point-wise numeric operation between two arrays of the same type and length.
79pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
80    arrow_numeric(lhs, rhs, op)
81}
82
83impl Options for NumericOperator {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87}
88
89/// Implementation of numeric operations using the Arrow crate.
90pub(crate) fn arrow_numeric(
91    lhs: &dyn Array,
92    rhs: &dyn Array,
93    operator: NumericOperator,
94) -> VortexResult<ArrayRef> {
95    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
96    let len = lhs.len();
97
98    let left = Datum::try_new(lhs)?;
99    let right = Datum::try_new_with_target_datatype(rhs, left.data_type())?;
100
101    let array = match operator {
102        NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
103        NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
104        NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
105        NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
106        NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
107        NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
108    };
109
110    from_arrow_array_with_len(array.as_ref(), len, nullable)
111}
112
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::sub_scalar;

    /// Unsigned subtraction of a scalar from each element.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let actual = sub_scalar(&input, 1u16.into()).unwrap();
        assert_arrays_eq!(actual, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    /// Subtracting a negative scalar increases each signed element.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let actual = sub_scalar(&input, (-1i64).into()).unwrap();
        assert_arrays_eq!(actual, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    /// Null input slots stay null after subtraction.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let actual = sub_scalar(input.as_ref(), Some(1u16).into()).unwrap();
        assert_arrays_eq!(
            actual,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    /// Floating-point subtraction of a negative scalar.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let actual = sub_scalar(&input, subtrahend.into()).unwrap();
        assert_arrays_eq!(actual, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    /// Float subtraction past the representable minimum must not return an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _unused = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }
}