1use std::ops::BitAnd;
5
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::NativePType;
8use vortex_vector::primitive::{PVector, PVectorMut};
9use vortex_vector::{VectorMutOps, VectorOps};
10
11use crate::arithmetic::CheckedArithmetic;
12
13impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVector<T>
15where
16 T: NativePType,
17 PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
18 for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
19{
20 type Output = PVector<T>;
21
22 fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
23 match self.try_into_mut() {
24 Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
25 Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
26 }
27 }
28}
29
30impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVectorMut<T>
32where
33 T: NativePType,
34 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
35{
36 type Output = PVector<T>;
37
38 fn checked_eval(self, other: &PVector<T>) -> Option<Self::Output> {
39 assert_eq!(self.len(), other.len());
40
41 let (lhs_buffer, lhs_validity) = self.into_parts();
42
43 let validity = lhs_validity.freeze().bitand(other.validity());
46 let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, other.elements())?;
47
48 Some(PVector::new(elements, validity))
49 }
50}
51
52impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for &PVector<T>
54where
55 T: NativePType,
56 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
57{
58 type Output = PVector<T>;
59
60 fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
61 assert_eq!(self.len(), rhs.len());
62
63 let validity = self.validity().bitand(rhs.validity());
66
67 let elements = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs.elements())?;
68 Some(PVector::new(elements, validity))
69 }
70}
71
72impl<Op, T> CheckedArithmetic<Op, &T> for PVector<T>
75where
76 T: NativePType,
77 PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
78 for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
79{
80 type Output = PVector<T>;
81
82 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
83 match self.try_into_mut() {
84 Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
85 Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
86 }
87 }
88}
89
90impl<Op, T> CheckedArithmetic<Op, &T> for PVectorMut<T>
92where
93 T: NativePType,
94 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
95{
96 type Output = PVector<T>;
97
98 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
99 let (lhs_buffer, lhs_validity) = self.into_parts();
100 let validity = lhs_validity.freeze();
101
102 let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, rhs)?;
103
104 Some(PVector::new(elements, validity))
105 }
106}
107
108impl<Op, T> CheckedArithmetic<Op, &T> for &PVector<T>
110where
111 T: NativePType,
112 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
113{
114 type Output = PVector<T>;
115
116 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
117 let buffer = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs)?;
118 Some(PVector::new(buffer, self.validity().clone()))
119 }
120}
121
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub};

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = CheckedArithmetic::<Sub, _>::checked_eval(vec, &5).unwrap();
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_pvectors() {
        let left = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let vec = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let result = CheckedArithmetic::<Div, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_overflow_returns_none() {
        let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2));
        let right = PVector::new(buffer![1u8, 50], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_by_zero_returns_none() {
        let left = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));
        let right = PVector::new(buffer![2u32, 0, 3], Mask::new_true(3));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }

    /// Exercises the borrowed `&PVector op &PVector` impl directly. That impl is also the
    /// fallback used when `try_into_mut` fails. Validity must be the intersection of both
    /// masks, and the inputs must be left untouched.
    #[test]
    fn test_add_borrowed_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, true, false]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([false, true, true]));

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        assert_eq!(result.validity(), &Mask::from_iter([false, true, false]));

        // Borrowed operands are not consumed or modified.
        assert_eq!(left.elements(), &buffer![1u32, 2, 3]);
        assert_eq!(right.elements(), &buffer![10u32, 20, 30]);
    }

    /// Exercises the borrowed `&PVector op &T` impl; validity is carried over unchanged.
    #[test]
    fn test_mul_borrowed_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));

        let result = CheckedArithmetic::<Mul, _>::checked_eval(&vec, &3).unwrap();
        assert_eq!(result.elements(), &buffer![3u32, 6, 9]);
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(vec.elements(), &buffer![1u32, 2, 3]);
    }

    /// Overflow through the scalar path must also surface as `None`.
    #[test]
    fn test_scalar_overflow_returns_none() {
        let vec = PVector::new(buffer![1u8, u8::MAX], Mask::new_true(2));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &1);
        assert!(result.is_none());
    }

    /// Vector/vector operations assert that both operands have the same length.
    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = PVector::new(buffer![1u32, 2], Mask::new_true(2));
        let right = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let _ = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
    }
}