vortex_array/compute/arrays/
arithmetic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::{Hash, Hasher};
5use std::sync::LazyLock;
6
7use enum_map::{Enum, EnumMap, enum_map};
8use vortex_buffer::ByteBuffer;
9use vortex_compute::arithmetic::{
10    Add, Arithmetic, CheckedArithmetic, CheckedOperator, Div, Mul, Operator, Sub,
11};
12use vortex_dtype::{DType, NativePType, PTypeDowncastExt, match_each_native_ptype};
13use vortex_error::{VortexExpect, VortexResult, vortex_err};
14use vortex_scalar::{PValue, Scalar};
15use vortex_vector::primitive::PVector;
16
17use crate::arrays::ConstantArray;
18use crate::execution::{BatchKernelRef, BindCtx, kernel};
19use crate::serde::ArrayChildren;
20use crate::stats::{ArrayStats, StatsSetRef};
21use crate::vtable::{
22    ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable,
23};
24use crate::{
25    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef,
26    DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, IntoArray, Precision, vtable,
27};
28
29/// The set of operators supported by an arithmetic array.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)]
31pub enum ArithmeticOperator {
32    /// Addition - errors on overflow for integers.
33    Add,
34    /// Subtraction - errors on overflow for integers.
35    Sub,
36    /// Multiplication - errors on overflow for integers.
37    Mul,
38    /// Division - errors on division by zero for integers.
39    Div,
40}
41
42vtable!(Arithmetic);
43
44#[derive(Debug, Clone)]
45pub struct ArithmeticArray {
46    encoding: EncodingRef,
47    lhs: ArrayRef,
48    rhs: ArrayRef,
49    stats: ArrayStats,
50}
51
52impl ArithmeticArray {
53    /// Create a new arithmetic array.
54    pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: ArithmeticOperator) -> Self {
55        assert_eq!(
56            lhs.len(),
57            rhs.len(),
58            "Arithmetic arrays require lhs and rhs to have the same length"
59        );
60
61        // TODO(ngates): should we automatically cast non-null to nullable if required?
62        assert!(matches!(lhs.dtype(), DType::Primitive(..)));
63        assert_eq!(lhs.dtype(), rhs.dtype());
64
65        Self {
66            encoding: ENCODINGS[operator].clone(),
67            lhs,
68            rhs,
69            stats: ArrayStats::default(),
70        }
71    }
72
73    /// Returns the operator of this logical array.
74    pub fn operator(&self) -> ArithmeticOperator {
75        self.encoding.as_::<ArithmeticVTable>().operator
76    }
77}
78
79#[derive(Debug, Clone)]
80pub struct ArithmeticEncoding {
81    // We include the operator in the encoding so each operator is a different encoding ID.
82    // This makes it easier for plugins to construct expressions and perform pushdown
83    // optimizations.
84    operator: ArithmeticOperator,
85}
86
87#[allow(clippy::mem_forget)]
88static ENCODINGS: LazyLock<EnumMap<ArithmeticOperator, EncodingRef>> = LazyLock::new(|| {
89    enum_map! {
90        operator => ArithmeticEncoding { operator }.to_encoding(),
91    }
92});
93
94impl VTable for ArithmeticVTable {
95    type Array = ArithmeticArray;
96    type Encoding = ArithmeticEncoding;
97    type ArrayVTable = Self;
98    type CanonicalVTable = NotSupported;
99    type OperationsVTable = NotSupported;
100    type ValidityVTable = NotSupported;
101    type VisitorVTable = Self;
102    type ComputeVTable = NotSupported;
103    type EncodeVTable = NotSupported;
104    type SerdeVTable = Self;
105    type OperatorVTable = Self;
106
107    fn id(encoding: &Self::Encoding) -> EncodingId {
108        match encoding.operator {
109            ArithmeticOperator::Add => EncodingId::from("vortex.add"),
110            ArithmeticOperator::Sub => EncodingId::from("vortex.sub"),
111            ArithmeticOperator::Mul => EncodingId::from("vortex.mul"),
112            ArithmeticOperator::Div => EncodingId::from("vortex.div"),
113        }
114    }
115
116    fn encoding(array: &Self::Array) -> EncodingRef {
117        array.encoding.clone()
118    }
119}
120
121impl ArrayVTable<ArithmeticVTable> for ArithmeticVTable {
122    fn len(array: &ArithmeticArray) -> usize {
123        array.lhs.len()
124    }
125
126    fn dtype(array: &ArithmeticArray) -> &DType {
127        array.lhs.dtype()
128    }
129
130    fn stats(array: &ArithmeticArray) -> StatsSetRef<'_> {
131        array.stats.to_ref(array.as_ref())
132    }
133
134    fn array_hash<H: Hasher>(array: &ArithmeticArray, state: &mut H, precision: Precision) {
135        array.lhs.array_hash(state, precision);
136        array.rhs.array_hash(state, precision);
137    }
138
139    fn array_eq(array: &ArithmeticArray, other: &ArithmeticArray, precision: Precision) -> bool {
140        array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision)
141    }
142}
143
144impl VisitorVTable<ArithmeticVTable> for ArithmeticVTable {
145    fn visit_buffers(_array: &ArithmeticArray, _visitor: &mut dyn ArrayBufferVisitor) {
146        // No buffers
147    }
148
149    fn visit_children(array: &ArithmeticArray, visitor: &mut dyn ArrayChildVisitor) {
150        visitor.visit_child("lhs", array.lhs.as_ref());
151        visitor.visit_child("rhs", array.rhs.as_ref());
152    }
153}
154
155impl SerdeVTable<ArithmeticVTable> for ArithmeticVTable {
156    type Metadata = EmptyMetadata;
157
158    fn metadata(_array: &ArithmeticArray) -> VortexResult<Option<Self::Metadata>> {
159        Ok(Some(EmptyMetadata))
160    }
161
162    fn build(
163        encoding: &ArithmeticEncoding,
164        dtype: &DType,
165        len: usize,
166        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
167        buffers: &[ByteBuffer],
168        children: &dyn ArrayChildren,
169    ) -> VortexResult<ArithmeticArray> {
170        assert!(buffers.is_empty());
171
172        Ok(ArithmeticArray::new(
173            children.get(0, dtype, len)?,
174            children.get(1, dtype, len)?,
175            encoding.operator,
176        ))
177    }
178}
179
180impl OperatorVTable<ArithmeticVTable> for ArithmeticVTable {
181    fn reduce_children(array: &ArithmeticArray) -> VortexResult<Option<ArrayRef>> {
182        match (array.lhs.as_constant(), array.rhs.as_constant()) {
183            // If both sides are constant, we compute the value now.
184            (Some(lhs), Some(rhs)) => {
185                let op: vortex_scalar::NumericOperator = match array.operator() {
186                    ArithmeticOperator::Add => vortex_scalar::NumericOperator::Add,
187                    ArithmeticOperator::Sub => vortex_scalar::NumericOperator::Sub,
188                    ArithmeticOperator::Mul => vortex_scalar::NumericOperator::Mul,
189                    ArithmeticOperator::Div => vortex_scalar::NumericOperator::Div,
190                };
191                let result = lhs
192                    .as_primitive()
193                    .checked_binary_numeric(&rhs.as_primitive(), op)
194                    .ok_or_else(|| {
195                        vortex_err!("Constant arithmetic operation resulted in overflow")
196                    })?;
197                return Ok(Some(
198                    ConstantArray::new(Scalar::from(result), array.len()).into_array(),
199                ));
200            }
201            // If either side is constant null, the result is constant null.
202            (Some(lhs), _) if lhs.is_null() => {
203                return Ok(Some(
204                    ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
205                        .into_array(),
206                ));
207            }
208            (_, Some(rhs)) if rhs.is_null() => {
209                return Ok(Some(
210                    ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
211                        .into_array(),
212                ));
213            }
214            _ => {}
215        }
216
217        Ok(None)
218    }
219
220    fn bind(
221        array: &ArithmeticArray,
222        selection: Option<&ArrayRef>,
223        ctx: &mut dyn BindCtx,
224    ) -> VortexResult<BatchKernelRef> {
225        // Optimize for constant RHS
226        if let Some(rhs_scalar) = array.rhs.as_constant() {
227            if rhs_scalar.is_null() {
228                // If the RHS is null, the result is always null.
229                return ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
230                    .into_array()
231                    .bind(selection, ctx);
232            }
233
234            let lhs = ctx.bind(&array.lhs, selection)?;
235            return match_each_native_ptype!(
236                    array.dtype().as_ptype(),
237                    integral: |T| {
238                        let rhs: T = rhs_scalar
239                            .as_primitive()
240                            .typed_value::<T>()
241                            .vortex_expect("Already checked for null above");
242                        Ok(match array.operator() {
243                            ArithmeticOperator::Add => checked_arithmetic_scalar_kernel::<Add, T>(lhs, rhs),
244                            ArithmeticOperator::Sub => checked_arithmetic_scalar_kernel::<Sub, T>(lhs, rhs),
245                            ArithmeticOperator::Mul => checked_arithmetic_scalar_kernel::<Mul, T>(lhs, rhs),
246                            ArithmeticOperator::Div => checked_arithmetic_scalar_kernel::<Div, T>(lhs, rhs),
247                        })
248                    },
249                    floating: |T| {
250                        let rhs: T = rhs_scalar
251                            .as_primitive()
252                            .typed_value::<T>()
253                            .vortex_expect("Already checked for null above");
254                        Ok(match array.operator() {
255                            ArithmeticOperator::Add => arithmetic_scalar_kernel::<Add, T>(lhs, rhs),
256                            ArithmeticOperator::Sub => arithmetic_scalar_kernel::<Sub, T>(lhs, rhs),
257                            ArithmeticOperator::Mul => arithmetic_scalar_kernel::<Mul, T>(lhs, rhs),
258                            ArithmeticOperator::Div => arithmetic_scalar_kernel::<Div, T>(lhs, rhs),
259                        })
260                    }
261            );
262        }
263
264        let lhs = ctx.bind(&array.lhs, selection)?;
265        let rhs = ctx.bind(&array.rhs, selection)?;
266
267        match_each_native_ptype!(
268            array.dtype().as_ptype(),
269            integral: |T| {
270                Ok(match array.operator() {
271                    ArithmeticOperator::Add => checked_arithmetic_kernel::<Add, T>(lhs, rhs),
272                    ArithmeticOperator::Sub => checked_arithmetic_kernel::<Sub, T>(lhs, rhs),
273                    ArithmeticOperator::Mul => checked_arithmetic_kernel::<Mul, T>(lhs, rhs),
274                    ArithmeticOperator::Div => checked_arithmetic_kernel::<Div, T>(lhs, rhs),
275                })
276            },
277            floating: |T| {
278                Ok(match array.operator() {
279                    ArithmeticOperator::Add => arithmetic_kernel::<Add, T>(lhs, rhs),
280                    ArithmeticOperator::Sub => arithmetic_kernel::<Sub, T>(lhs, rhs),
281                    ArithmeticOperator::Mul => arithmetic_kernel::<Mul, T>(lhs, rhs),
282                    ArithmeticOperator::Div => arithmetic_kernel::<Div, T>(lhs, rhs),
283                })
284            }
285        )
286    }
287}
288
289fn arithmetic_kernel<Op, T>(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef
290where
291    T: NativePType,
292    Op: Operator<T>,
293{
294    kernel(move || {
295        let lhs = lhs.execute()?.into_primitive().downcast::<T>();
296        let rhs = rhs.execute()?.into_primitive().downcast::<T>();
297        let result = Arithmetic::<Op, _>::eval(lhs, &rhs);
298        Ok(result.into())
299    })
300}
301
302fn arithmetic_scalar_kernel<Op, T>(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef
303where
304    T: NativePType + TryFrom<PValue>,
305    Op: Operator<T>,
306{
307    kernel(move || {
308        let lhs = lhs.execute()?.into_primitive().downcast::<T>();
309        let result = Arithmetic::<Op, _>::eval(lhs, &rhs);
310        Ok(result.into())
311    })
312}
313
314fn checked_arithmetic_kernel<Op, T>(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef
315where
316    T: NativePType,
317    Op: CheckedOperator<T>,
318    PVector<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
319{
320    kernel(move || {
321        let lhs = lhs.execute()?.into_primitive().downcast::<T>();
322        let rhs = rhs.execute()?.into_primitive().downcast::<T>();
323        let result = CheckedArithmetic::<Op, _>::checked_eval(lhs, &rhs)
324            .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
325        Ok(result.into())
326    })
327}
328
329fn checked_arithmetic_scalar_kernel<Op, T>(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef
330where
331    T: NativePType + TryFrom<PValue>,
332    Op: CheckedOperator<T>,
333    PVector<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
334{
335    kernel(move || {
336        let lhs = lhs.execute()?.into_primitive().downcast::<T>();
337        let result = CheckedArithmetic::<Op, _>::checked_eval(lhs, &rhs)
338            .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
339        Ok(result.into())
340    })
341}
342
#[cfg(test)]
mod tests {
    use vortex_buffer::{bitbuffer, buffer};
    use vortex_dtype::PTypeDowncastExt;

    use crate::arrays::PrimitiveArray;
    use crate::compute::arrays::arithmetic::{ArithmeticArray, ArithmeticOperator};
    use crate::{ArrayRef, IntoArray};

    /// Builds `lhs <operator> rhs` over two three-element `u32` primitive arrays.
    fn arithmetic(operator: ArithmeticOperator, lhs: [u32; 3], rhs: [u32; 3]) -> ArrayRef {
        let lhs = PrimitiveArray::from_iter(lhs).into_array();
        let rhs = PrimitiveArray::from_iter(rhs).into_array();
        ArithmeticArray::new(lhs, rhs, operator).into_array()
    }

    #[test]
    fn test_add() {
        let sum = arithmetic(ArithmeticOperator::Add, [1, 2, 3], [10, 20, 30])
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(sum.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub() {
        let difference = arithmetic(ArithmeticOperator::Sub, [10, 20, 30], [1, 2, 3])
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(difference.elements(), &buffer![9u32, 18, 27]);
    }

    #[test]
    fn test_mul() {
        let product = arithmetic(ArithmeticOperator::Mul, [2, 3, 4], [10, 20, 30])
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(product.elements(), &buffer![20u32, 60, 120]);
    }

    #[test]
    fn test_div() {
        let quotient = arithmetic(ArithmeticOperator::Div, [100, 200, 300], [10, 20, 30])
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(quotient.elements(), &buffer![10u32, 10, 10]);
    }

    #[test]
    fn test_add_with_selection() {
        // Only the first and last positions are selected.
        let selection = bitbuffer![1 0 1].into();

        let sum = arithmetic(ArithmeticOperator::Add, [1, 2, 3], [10, 20, 30])
            .execute_with_selection(&selection)
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(sum.elements(), &buffer![11u32, 33]);
    }
}