vortex_compute/arithmetic/primitive_scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
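//! Arithmetic implementations for the `PrimitiveScalar` enum: checked arithmetic between
//! integer scalars, and arithmetic between float scalars or a float scalar and a float vector.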
5
6use vortex_dtype::half::f16;
7use vortex_error::vortex_panic;
8use vortex_vector::PrimitiveDatum;
9use vortex_vector::match_each_float_pscalar_pair;
10use vortex_vector::match_each_integer_pscalar_pair;
11use vortex_vector::primitive::PScalar;
12use vortex_vector::primitive::PVector;
13use vortex_vector::primitive::PrimitiveScalar;
14use vortex_vector::primitive::PrimitiveVector;
15
16use crate::arithmetic::Arithmetic;
17use crate::arithmetic::CheckedArithmetic;
18
19impl<Op> CheckedArithmetic<Op> for &PrimitiveScalar
20where
21    for<'a> &'a PScalar<i8>: CheckedArithmetic<Op, &'a PScalar<i8>, Output = PScalar<i8>>,
22    for<'a> &'a PScalar<i16>: CheckedArithmetic<Op, &'a PScalar<i16>, Output = PScalar<i16>>,
23    for<'a> &'a PScalar<i32>: CheckedArithmetic<Op, &'a PScalar<i32>, Output = PScalar<i32>>,
24    for<'a> &'a PScalar<i64>: CheckedArithmetic<Op, &'a PScalar<i64>, Output = PScalar<i64>>,
25    for<'a> &'a PScalar<u8>: CheckedArithmetic<Op, &'a PScalar<u8>, Output = PScalar<u8>>,
26    for<'a> &'a PScalar<u16>: CheckedArithmetic<Op, &'a PScalar<u16>, Output = PScalar<u16>>,
27    for<'a> &'a PScalar<u32>: CheckedArithmetic<Op, &'a PScalar<u32>, Output = PScalar<u32>>,
28    for<'a> &'a PScalar<u64>: CheckedArithmetic<Op, &'a PScalar<u64>, Output = PScalar<u64>>,
29{
30    type Output = PrimitiveScalar;
31
32    fn checked_eval(self, rhs: &PrimitiveScalar) -> Option<Self::Output> {
33        match_each_integer_pscalar_pair!(
34            (&self, &rhs),
35            |l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
36            { vortex_panic!("cannot compare primitive scalar of different types") }
37        )
38    }
39}
40
41impl<Op> Arithmetic<Op, &PrimitiveScalar> for &PrimitiveScalar
42where
43    for<'a> &'a PScalar<f16>: Arithmetic<Op, &'a PScalar<f16>, Output = PScalar<f16>>,
44    for<'a> &'a PScalar<f32>: Arithmetic<Op, &'a PScalar<f32>, Output = PScalar<f32>>,
45    for<'a> &'a PScalar<f64>: Arithmetic<Op, &'a PScalar<f64>, Output = PScalar<f64>>,
46{
47    type Output = PrimitiveScalar;
48
49    fn eval(self, rhs: &PrimitiveScalar) -> Self::Output {
50        match_each_float_pscalar_pair!(
51            (self, rhs),
52            |l, r| { Arithmetic::<Op, _>::eval(l, r).into() },
53            {
54                vortex_panic!(
55                    "Cannot perform arithmetic on PrimitiveScalars of different types: {:?} and {:?}",
56                    self,
57                    rhs
58                )
59            }
60        )
61    }
62}
63
64/// Scalar on LHS, owned Vector on RHS - modifies vector in place.
65/// Returns a scalar if the input scalar is null.
66impl<Op> Arithmetic<Op, PrimitiveVector> for &PrimitiveScalar
67where
68    for<'a> &'a f16: Arithmetic<Op, PVector<f16>, Output = PVector<f16>>,
69    for<'a> &'a f32: Arithmetic<Op, PVector<f32>, Output = PVector<f32>>,
70    for<'a> &'a f64: Arithmetic<Op, PVector<f64>, Output = PVector<f64>>,
71{
72    type Output = PrimitiveDatum;
73
74    fn eval(self, rhs: PrimitiveVector) -> Self::Output {
75        match (self, rhs) {
76            (PrimitiveScalar::F16(s), PrimitiveVector::F16(v)) => match s.value() {
77                Some(scalar_val) => {
78                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
79                }
80                None => PrimitiveDatum::Scalar(s.clone().into()),
81            },
82            (PrimitiveScalar::F32(s), PrimitiveVector::F32(v)) => match s.value() {
83                Some(scalar_val) => {
84                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
85                }
86                None => PrimitiveDatum::Scalar(s.clone().into()),
87            },
88            (PrimitiveScalar::F64(s), PrimitiveVector::F64(v)) => match s.value() {
89                Some(scalar_val) => {
90                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
91                }
92                None => PrimitiveDatum::Scalar(s.clone().into()),
93            },
94            (s, v) => vortex_panic!(
95                "Cannot perform arithmetic between scalar {:?} and vector {:?}",
96                s,
97                v
98            ),
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    // `PScalar` and the arithmetic traits come in through the parent's imports.
    use super::*;
    use crate::arithmetic::Add;

    #[test]
    fn test_checked_add_i32() {
        let left: PrimitiveScalar = PScalar::new(Some(5i32)).into();
        let right: PrimitiveScalar = PScalar::new(Some(3i32)).into();

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        match result {
            PrimitiveScalar::I32(v) => assert_eq!(v.value(), Some(8)),
            other => panic!("Expected I32 result, got {:?}", other),
        }
    }

    #[test]
    fn test_checked_add_overflow() {
        let left: PrimitiveScalar = PScalar::new(Some(i32::MAX)).into();
        let right: PrimitiveScalar = PScalar::new(Some(1i32)).into();

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    /// Checked arithmetic only dispatches same-typed integer pairs; anything else panics.
    #[test]
    #[should_panic]
    fn test_checked_add_mismatched_types_panics() {
        let left: PrimitiveScalar = PScalar::new(Some(1i32)).into();
        let right: PrimitiveScalar = PScalar::new(Some(1i64)).into();

        let _ = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
    }

    #[test]
    fn test_float_add() {
        let left: PrimitiveScalar = PScalar::new(Some(1.5f64)).into();
        let right: PrimitiveScalar = PScalar::new(Some(2.5f64)).into();

        let result = Arithmetic::<Add, _>::eval(&left, &right);
        match result {
            PrimitiveScalar::F64(v) => assert_eq!(v.value(), Some(4.0)),
            other => panic!("Expected F64 result, got {:?}", other),
        }
    }

    #[test]
    fn test_float_add_f32() {
        let left: PrimitiveScalar = PScalar::new(Some(0.5f32)).into();
        let right: PrimitiveScalar = PScalar::new(Some(0.25f32)).into();

        let result = Arithmetic::<Add, _>::eval(&left, &right);
        match result {
            PrimitiveScalar::F32(v) => assert_eq!(v.value(), Some(0.75)),
            other => panic!("Expected F32 result, got {:?}", other),
        }
    }

    /// Float arithmetic requires both operands to have the same float width.
    #[test]
    #[should_panic]
    fn test_float_add_mismatched_types_panics() {
        let left: PrimitiveScalar = PScalar::new(Some(1.0f32)).into();
        let right: PrimitiveScalar = PScalar::new(Some(1.0f64)).into();

        let _ = Arithmetic::<Add, _>::eval(&left, &right);
    }
}