vortex_compute/arithmetic/primitive_vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Arithmetic implementations for PrimitiveVector enum.
5
6use vortex_dtype::half::f16;
7use vortex_error::vortex_panic;
8use vortex_vector::PrimitiveDatum;
9use vortex_vector::match_each_float_pvector_pair;
10use vortex_vector::match_each_integer_pvector_pair;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PrimitiveScalar;
13use vortex_vector::primitive::PrimitiveVector;
14
15use crate::arithmetic::Arithmetic;
16use crate::arithmetic::CheckedArithmetic;
17
18impl<Op> CheckedArithmetic<Op, &PrimitiveVector> for PrimitiveVector
19where
20    for<'a> PVector<i8>: CheckedArithmetic<Op, &'a PVector<i8>, Output = PVector<i8>>,
21    for<'a> PVector<i16>: CheckedArithmetic<Op, &'a PVector<i16>, Output = PVector<i16>>,
22    for<'a> PVector<i32>: CheckedArithmetic<Op, &'a PVector<i32>, Output = PVector<i32>>,
23    for<'a> PVector<i64>: CheckedArithmetic<Op, &'a PVector<i64>, Output = PVector<i64>>,
24    for<'a> PVector<u8>: CheckedArithmetic<Op, &'a PVector<u8>, Output = PVector<u8>>,
25    for<'a> PVector<u16>: CheckedArithmetic<Op, &'a PVector<u16>, Output = PVector<u16>>,
26    for<'a> PVector<u32>: CheckedArithmetic<Op, &'a PVector<u32>, Output = PVector<u32>>,
27    for<'a> PVector<u64>: CheckedArithmetic<Op, &'a PVector<u64>, Output = PVector<u64>>,
28{
29    type Output = PrimitiveVector;
30
31    fn checked_eval(self, rhs: &PrimitiveVector) -> Option<Self::Output> {
32        match_each_integer_pvector_pair!(
33            (self, &rhs),
34            |l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
35            { vortex_panic!("dont use checked arithmetic for floats") }
36        )
37    }
38}
39
40impl<Op> Arithmetic<Op, &PrimitiveVector> for PrimitiveVector
41where
42    for<'a> PVector<f16>: Arithmetic<Op, &'a PVector<f16>, Output = PVector<f16>>,
43    for<'a> PVector<f32>: Arithmetic<Op, &'a PVector<f32>, Output = PVector<f32>>,
44    for<'a> PVector<f64>: Arithmetic<Op, &'a PVector<f64>, Output = PVector<f64>>,
45{
46    type Output = PrimitiveVector;
47
48    fn eval(self, rhs: &PrimitiveVector) -> Self::Output {
49        match_each_float_pvector_pair!(
50            (self, rhs),
51            |l, r| { Arithmetic::<Op, _>::eval(l, r).into() },
52            |l, r| {
53                vortex_panic!(
54                    "Cannot perform arithmetic on PrimitiveVectors of different types: {:?} and {:?}",
55                    l,
56                    r
57                )
58            }
59        )
60    }
61}
62
63/// Vector on LHS, Scalar on RHS - modifies vector in place.
64/// Returns a scalar if the input scalar is null.
65impl<Op> Arithmetic<Op, &PrimitiveScalar> for PrimitiveVector
66where
67    for<'a> PVector<f16>: Arithmetic<Op, &'a f16, Output = PVector<f16>>,
68    for<'a> PVector<f32>: Arithmetic<Op, &'a f32, Output = PVector<f32>>,
69    for<'a> PVector<f64>: Arithmetic<Op, &'a f64, Output = PVector<f64>>,
70{
71    type Output = PrimitiveDatum;
72
73    fn eval(self, rhs: &PrimitiveScalar) -> Self::Output {
74        match (self, rhs) {
75            (PrimitiveVector::F16(v), PrimitiveScalar::F16(s)) => match s.value() {
76                Some(scalar_val) => {
77                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
78                }
79                None => PrimitiveDatum::Scalar(s.clone().into()),
80            },
81            (PrimitiveVector::F32(v), PrimitiveScalar::F32(s)) => match s.value() {
82                Some(scalar_val) => {
83                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
84                }
85                None => PrimitiveDatum::Scalar(s.clone().into()),
86            },
87            (PrimitiveVector::F64(v), PrimitiveScalar::F64(s)) => match s.value() {
88                Some(scalar_val) => {
89                    PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
90                }
91                None => PrimitiveDatum::Scalar(s.clone().into()),
92            },
93            (v, s) => vortex_panic!(
94                "Cannot perform arithmetic between vector {:?} and scalar {:?}",
95                v,
96                s
97            ),
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use vortex_vector::VectorMutOps;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVectorMut;

    use super::*;
    use crate::arithmetic::Add;

    /// Checked addition of two same-typed `i32` vectors yields an `I32` vector of sums.
    #[test]
    fn test_checked_add_i32() {
        let lhs = PVectorMut::from_iter([1i32, 2, 3].map(Some)).freeze();
        let lhs: PrimitiveVector = lhs.into();
        let rhs = PVectorMut::from_iter([10i32, 20, 30].map(Some)).freeze();
        let rhs: PrimitiveVector = rhs.into();

        let sum = CheckedArithmetic::<Add, _>::checked_eval(lhs, &rhs).unwrap();
        let PrimitiveVector::I32(sum) = sum else {
            panic!("Expected I32 result");
        };

        assert_eq!(sum.scalar_at(0).value(), Some(11));
        assert_eq!(sum.scalar_at(1).value(), Some(22));
        assert_eq!(sum.scalar_at(2).value(), Some(33));
    }

    /// Unchecked addition of two same-typed `f64` vectors yields an `F64` vector of sums.
    #[test]
    fn test_float_add() {
        let lhs = PVectorMut::from_iter([1.0f64, 2.0, 3.0].map(Some)).freeze();
        let lhs: PrimitiveVector = lhs.into();
        let rhs = PVectorMut::from_iter([0.5f64, 0.5, 0.5].map(Some)).freeze();
        let rhs: PrimitiveVector = rhs.into();

        let sum = Arithmetic::<Add, _>::eval(lhs, &rhs);
        let PrimitiveVector::F64(sum) = sum else {
            panic!("Expected F64 result");
        };

        assert_eq!(sum.scalar_at(0).value(), Some(1.5));
        assert_eq!(sum.scalar_at(1).value(), Some(2.5));
        assert_eq!(sum.scalar_at(2).value(), Some(3.5));
    }
}