vortex_compute/arithmetic/
primitive_vector.rs1use vortex_dtype::half::f16;
7use vortex_error::vortex_panic;
8use vortex_vector::PrimitiveDatum;
9use vortex_vector::match_each_float_pvector_pair;
10use vortex_vector::match_each_integer_pvector_pair;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PrimitiveScalar;
13use vortex_vector::primitive::PrimitiveVector;
14
15use crate::arithmetic::Arithmetic;
16use crate::arithmetic::CheckedArithmetic;
17
18impl<Op> CheckedArithmetic<Op, &PrimitiveVector> for PrimitiveVector
19where
20 for<'a> PVector<i8>: CheckedArithmetic<Op, &'a PVector<i8>, Output = PVector<i8>>,
21 for<'a> PVector<i16>: CheckedArithmetic<Op, &'a PVector<i16>, Output = PVector<i16>>,
22 for<'a> PVector<i32>: CheckedArithmetic<Op, &'a PVector<i32>, Output = PVector<i32>>,
23 for<'a> PVector<i64>: CheckedArithmetic<Op, &'a PVector<i64>, Output = PVector<i64>>,
24 for<'a> PVector<u8>: CheckedArithmetic<Op, &'a PVector<u8>, Output = PVector<u8>>,
25 for<'a> PVector<u16>: CheckedArithmetic<Op, &'a PVector<u16>, Output = PVector<u16>>,
26 for<'a> PVector<u32>: CheckedArithmetic<Op, &'a PVector<u32>, Output = PVector<u32>>,
27 for<'a> PVector<u64>: CheckedArithmetic<Op, &'a PVector<u64>, Output = PVector<u64>>,
28{
29 type Output = PrimitiveVector;
30
31 fn checked_eval(self, rhs: &PrimitiveVector) -> Option<Self::Output> {
32 match_each_integer_pvector_pair!(
33 (self, &rhs),
34 |l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
35 { vortex_panic!("dont use checked arithmetic for floats") }
36 )
37 }
38}
39
40impl<Op> Arithmetic<Op, &PrimitiveVector> for PrimitiveVector
41where
42 for<'a> PVector<f16>: Arithmetic<Op, &'a PVector<f16>, Output = PVector<f16>>,
43 for<'a> PVector<f32>: Arithmetic<Op, &'a PVector<f32>, Output = PVector<f32>>,
44 for<'a> PVector<f64>: Arithmetic<Op, &'a PVector<f64>, Output = PVector<f64>>,
45{
46 type Output = PrimitiveVector;
47
48 fn eval(self, rhs: &PrimitiveVector) -> Self::Output {
49 match_each_float_pvector_pair!(
50 (self, rhs),
51 |l, r| { Arithmetic::<Op, _>::eval(l, r).into() },
52 |l, r| {
53 vortex_panic!(
54 "Cannot perform arithmetic on PrimitiveVectors of different types: {:?} and {:?}",
55 l,
56 r
57 )
58 }
59 )
60 }
61}
62
63impl<Op> Arithmetic<Op, &PrimitiveScalar> for PrimitiveVector
66where
67 for<'a> PVector<f16>: Arithmetic<Op, &'a f16, Output = PVector<f16>>,
68 for<'a> PVector<f32>: Arithmetic<Op, &'a f32, Output = PVector<f32>>,
69 for<'a> PVector<f64>: Arithmetic<Op, &'a f64, Output = PVector<f64>>,
70{
71 type Output = PrimitiveDatum;
72
73 fn eval(self, rhs: &PrimitiveScalar) -> Self::Output {
74 match (self, rhs) {
75 (PrimitiveVector::F16(v), PrimitiveScalar::F16(s)) => match s.value() {
76 Some(scalar_val) => {
77 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
78 }
79 None => PrimitiveDatum::Scalar(s.clone().into()),
80 },
81 (PrimitiveVector::F32(v), PrimitiveScalar::F32(s)) => match s.value() {
82 Some(scalar_val) => {
83 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
84 }
85 None => PrimitiveDatum::Scalar(s.clone().into()),
86 },
87 (PrimitiveVector::F64(v), PrimitiveScalar::F64(s)) => match s.value() {
88 Some(scalar_val) => {
89 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(v, &scalar_val).into())
90 }
91 None => PrimitiveDatum::Scalar(s.clone().into()),
92 },
93 (v, s) => vortex_panic!(
94 "Cannot perform arithmetic between vector {:?} and scalar {:?}",
95 v,
96 s
97 ),
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use vortex_vector::VectorMutOps;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVectorMut;

    use super::*;
    use crate::arithmetic::Add;

    /// Checked `Add` on two `i32` vectors produces the element-wise sums as an `I32` vector.
    #[test]
    fn test_checked_add_i32() {
        let lhs_values = [1i32, 2, 3];
        let rhs_values = [10i32, 20, 30];

        let left: PrimitiveVector = PVectorMut::from_iter(lhs_values.map(Some))
            .freeze()
            .into();
        let right: PrimitiveVector = PVectorMut::from_iter(rhs_values.map(Some))
            .freeze()
            .into();

        let sum = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        let PrimitiveVector::I32(typed) = sum else {
            panic!("Expected I32 result");
        };

        for (idx, expected) in [11i32, 22, 33].into_iter().enumerate() {
            assert_eq!(typed.scalar_at(idx).value(), Some(expected));
        }
    }

    /// Unchecked `Add` on two `f64` vectors produces the element-wise sums as an `F64` vector.
    #[test]
    fn test_float_add() {
        let lhs_values = [1.0f64, 2.0, 3.0];
        let rhs_values = [0.5f64, 0.5, 0.5];

        let left: PrimitiveVector = PVectorMut::from_iter(lhs_values.map(Some))
            .freeze()
            .into();
        let right: PrimitiveVector = PVectorMut::from_iter(rhs_values.map(Some))
            .freeze()
            .into();

        let sum = Arithmetic::<Add, _>::eval(left, &right);
        let PrimitiveVector::F64(typed) = sum else {
            panic!("Expected F64 result");
        };

        for (idx, expected) in [1.5f64, 2.5, 3.5].into_iter().enumerate() {
            assert_eq!(typed.scalar_at(idx).value(), Some(expected));
        }
    }
}