vortex_compute/arithmetic/
pvector_checked.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::NativePType;
8use vortex_vector::primitive::{PVector, PVectorMut};
9use vortex_vector::{VectorMutOps, VectorOps};
10
11use crate::arithmetic::CheckedArithmetic;
12
13/// Implementation that attempts to downcast to a mutable vector and operates in-place.
14impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVector<T>
15where
16    T: NativePType,
17    PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
18    for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
19{
20    type Output = PVector<T>;
21
22    fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
23        match self.try_into_mut() {
24            Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
25            Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
26        }
27    }
28}
29
30/// Implementation that operates in-place over a mutable vector.
31impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVectorMut<T>
32where
33    T: NativePType,
34    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
35{
36    type Output = PVector<T>;
37
38    fn checked_eval(self, other: &PVector<T>) -> Option<Self::Output> {
39        assert_eq!(self.len(), other.len());
40
41        let (lhs_buffer, lhs_validity) = self.into_parts();
42
43        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
44        //  or choose a different implementation.
45        let validity = lhs_validity.freeze().bitand(other.validity());
46        let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, other.elements())?;
47
48        Some(PVector::new(elements, validity))
49    }
50}
51
52/// Implementation that allocates a new output vector.
53impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for &PVector<T>
54where
55    T: NativePType,
56    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
57{
58    type Output = PVector<T>;
59
60    fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
61        assert_eq!(self.len(), rhs.len());
62
63        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
64        //  or choose a different implementation.
65        let validity = self.validity().bitand(rhs.validity());
66
67        let elements = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs.elements())?;
68        Some(PVector::new(elements, validity))
69    }
70}
71
72/// Implementation that attempts to downcast to a mutable vector and operates in-place against
73/// a scalar RHS value.
74impl<Op, T> CheckedArithmetic<Op, &T> for PVector<T>
75where
76    T: NativePType,
77    PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
78    for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
79{
80    type Output = PVector<T>;
81
82    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
83        match self.try_into_mut() {
84            Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
85            Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
86        }
87    }
88}
89
90/// Implementation that operates in-place over a mutable vector against a scalar RHS value.
91impl<Op, T> CheckedArithmetic<Op, &T> for PVectorMut<T>
92where
93    T: NativePType,
94    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
95{
96    type Output = PVector<T>;
97
98    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
99        let (lhs_buffer, lhs_validity) = self.into_parts();
100        let validity = lhs_validity.freeze();
101
102        let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, rhs)?;
103
104        Some(PVector::new(elements, validity))
105    }
106}
107
108/// Implementation that allocates a new output vector against a scalar RHS value.
109impl<Op, T> CheckedArithmetic<Op, &T> for &PVector<T>
110where
111    T: NativePType,
112    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
113{
114    type Output = PVector<T>;
115
116    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
117        let buffer = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs)?;
118        Some(PVector::new(buffer, self.validity().clone()))
119    }
120}
121
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub};

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_pvector_refs() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        // The borrowed path must leave its inputs untouched.
        assert_eq!(left.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_add_shared_pvector() {
        // Holding a second handle to the buffer presumably prevents `try_into_mut` from
        // succeeding, exercising the allocating fallback branch.
        let elements = buffer![1u32, 2, 3];
        let left = PVector::new(elements.clone(), Mask::new_true(3));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        assert_eq!(elements, buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        // Validity is AND'd, so if either side is null, result is null
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_add_with_nulls_on_both_sides() {
        let left = PVector::new(
            buffer![1u32, 2, 3, 4],
            Mask::from_iter([true, false, true, true]),
        );
        let right = PVector::new(
            buffer![10u32, 20, 30, 40],
            Mask::from_iter([true, true, false, true]),
        );

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(
            result.validity(),
            &Mask::from_iter([true, false, false, true])
        );
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = CheckedArithmetic::<Sub, _>::checked_eval(vec, &5).unwrap();
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_scalar_on_borrowed_pvector() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&vec, &3).unwrap();

        assert_eq!(result.elements(), &buffer![3u32, 6, 9]);
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        // The borrowed path must leave its input untouched.
        assert_eq!(vec.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_div_pvectors() {
        let left = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let vec = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let result = CheckedArithmetic::<Div, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_overflow_returns_none() {
        let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2));
        let right = PVector::new(buffer![1u8, 50], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_overflow_returns_none() {
        let vec = PVector::new(buffer![1u8, 2], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &u8::MAX);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_by_zero_returns_none() {
        let left = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));
        let right = PVector::new(buffer![2u32, 0, 3], Mask::new_true(3));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_div_by_zero_returns_none() {
        let vec = PVector::new(buffer![10u32, 20], Mask::new_true(2));

        let result = CheckedArithmetic::<Div, _>::checked_eval(&vec, &0);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }
}