1use std::ops::BitAnd;
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::NativePType;
9use vortex_vector::VectorMutOps;
10use vortex_vector::VectorOps;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PVectorMut;
13
14use crate::arithmetic::CheckedArithmetic;
15
16impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVector<T>
18where
19 T: NativePType,
20 PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
21 for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
22{
23 type Output = PVector<T>;
24
25 fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
26 match self.try_into_mut() {
27 Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
28 Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
29 }
30 }
31}
32
33impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVectorMut<T>
35where
36 T: NativePType,
37 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
38{
39 type Output = PVector<T>;
40
41 fn checked_eval(self, other: &PVector<T>) -> Option<Self::Output> {
42 assert_eq!(self.len(), other.len());
43
44 let (lhs_buffer, lhs_validity) = self.into_parts();
45
46 let validity = lhs_validity.freeze().bitand(other.validity());
49 let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, other.elements())?;
50
51 Some(PVector::new(elements, validity))
52 }
53}
54
55impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for &PVector<T>
57where
58 T: NativePType,
59 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
60{
61 type Output = PVector<T>;
62
63 fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
64 assert_eq!(self.len(), rhs.len());
65
66 let validity = self.validity().bitand(rhs.validity());
69
70 let elements = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs.elements())?;
71 Some(PVector::new(elements, validity))
72 }
73}
74
75impl<Op, T> CheckedArithmetic<Op, &T> for PVector<T>
78where
79 T: NativePType,
80 PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
81 for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
82{
83 type Output = PVector<T>;
84
85 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
86 match self.try_into_mut() {
87 Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
88 Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
89 }
90 }
91}
92
93impl<Op, T> CheckedArithmetic<Op, &T> for PVectorMut<T>
95where
96 T: NativePType,
97 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
98{
99 type Output = PVector<T>;
100
101 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
102 let (lhs_buffer, lhs_validity) = self.into_parts();
103 let validity = lhs_validity.freeze();
104
105 let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, rhs)?;
106
107 Some(PVector::new(elements, validity))
108 }
109}
110
111impl<Op, T> CheckedArithmetic<Op, &T> for &PVector<T>
113where
114 T: NativePType,
115 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
116{
117 type Output = PVector<T>;
118
119 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
120 let buffer = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs)?;
121 Some(PVector::new(buffer, self.validity().clone()))
122 }
123}
124
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::Add;
    use crate::arithmetic::CheckedArithmetic;
    use crate::arithmetic::Div;
    use crate::arithmetic::Mul;
    use crate::arithmetic::Sub;

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        // Values at null positions are still computed.
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_add_with_nulls_in_both_operands() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::from_iter([true, false, true, true]));
        let right =
            PVector::new(buffer![10u32, 20, 30, 40], Mask::from_iter([true, true, false, true]));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        // Result validity is the intersection of both operands' validity.
        assert_eq!(
            result.validity(),
            &Mask::from_iter([true, false, false, true])
        );
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_borrowed_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([true, true, false]));

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result.validity(), &Mask::from_iter([true, true, false]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        // The borrowed left operand must be left untouched.
        assert_eq!(left.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_add_shared_pvector() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        // Keeping a clone alive presumably shares the element buffer, so the owned impl
        // cannot reclaim it mutably and must fall back to the borrowed implementation.
        let keep_alive = left.clone();
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        assert_eq!(keep_alive.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = CheckedArithmetic::<Sub, _>::checked_eval(vec, &5).unwrap();
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_mul_scalar_borrowed() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([false, true, true]));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&vec, &3).unwrap();

        assert_eq!(result.validity(), &Mask::from_iter([false, true, true]));
        assert_eq!(result.elements(), &buffer![3u32, 6, 9]);
        assert_eq!(vec.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_div_pvectors() {
        let left = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let vec = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let result = CheckedArithmetic::<Div, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_overflow_returns_none() {
        let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2));
        let right = PVector::new(buffer![1u8, 50], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_overflow_borrowed_returns_none() {
        let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2));
        let right = PVector::new(buffer![1u8, 50], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_by_zero_returns_none() {
        let left = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));
        let right = PVector::new(buffer![2u32, 0, 3], Mask::new_true(3));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![1u32, 2], Mask::new_true(2));

        let _ = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }
}