vortex_compute/arithmetic/
pvector_checked.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::NativePType;
9use vortex_vector::VectorMutOps;
10use vortex_vector::VectorOps;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PVectorMut;
13
14use crate::arithmetic::CheckedArithmetic;
15
16/// Implementation that attempts to downcast to a mutable vector and operates in-place.
17impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVector<T>
18where
19    T: NativePType,
20    PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
21    for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
22{
23    type Output = PVector<T>;
24
25    fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
26        match self.try_into_mut() {
27            Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
28            Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
29        }
30    }
31}
32
33/// Implementation that operates in-place over a mutable vector.
34impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for PVectorMut<T>
35where
36    T: NativePType,
37    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
38{
39    type Output = PVector<T>;
40
41    fn checked_eval(self, other: &PVector<T>) -> Option<Self::Output> {
42        assert_eq!(self.len(), other.len());
43
44        let (lhs_buffer, lhs_validity) = self.into_parts();
45
46        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
47        //  or choose a different implementation.
48        let validity = lhs_validity.freeze().bitand(other.validity());
49        let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, other.elements())?;
50
51        Some(PVector::new(elements, validity))
52    }
53}
54
55/// Implementation that allocates a new output vector.
56impl<Op, T> CheckedArithmetic<Op, &PVector<T>> for &PVector<T>
57where
58    T: NativePType,
59    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
60{
61    type Output = PVector<T>;
62
63    fn checked_eval(self, rhs: &PVector<T>) -> Option<Self::Output> {
64        assert_eq!(self.len(), rhs.len());
65
66        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
67        //  or choose a different implementation.
68        let validity = self.validity().bitand(rhs.validity());
69
70        let elements = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs.elements())?;
71        Some(PVector::new(elements, validity))
72    }
73}
74
75/// Implementation that attempts to downcast to a mutable vector and operates in-place against
76/// a scalar RHS value.
77impl<Op, T> CheckedArithmetic<Op, &T> for PVector<T>
78where
79    T: NativePType,
80    PVectorMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
81    for<'a> &'a PVector<T>: CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
82{
83    type Output = PVector<T>;
84
85    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
86        match self.try_into_mut() {
87            Ok(lhs) => CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs),
88            Err(lhs) => CheckedArithmetic::<Op, _>::checked_eval(&lhs, rhs),
89        }
90    }
91}
92
93/// Implementation that operates in-place over a mutable vector against a scalar RHS value.
94impl<Op, T> CheckedArithmetic<Op, &T> for PVectorMut<T>
95where
96    T: NativePType,
97    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
98{
99    type Output = PVector<T>;
100
101    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
102        let (lhs_buffer, lhs_validity) = self.into_parts();
103        let validity = lhs_validity.freeze();
104
105        let elements = CheckedArithmetic::<Op, _>::checked_eval(lhs_buffer, rhs)?;
106
107        Some(PVector::new(elements, validity))
108    }
109}
110
111/// Implementation that allocates a new output vector against a scalar RHS value.
112impl<Op, T> CheckedArithmetic<Op, &T> for &PVector<T>
113where
114    T: NativePType,
115    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
116{
117    type Output = PVector<T>;
118
119    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
120        let buffer = CheckedArithmetic::<Op, _>::checked_eval(self.elements(), rhs)?;
121        Some(PVector::new(buffer, self.validity().clone()))
122    }
123}
124
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::Add;
    use crate::arithmetic::CheckedArithmetic;
    use crate::arithmetic::Div;
    use crate::arithmetic::Mul;
    use crate::arithmetic::Sub;

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        // Validity is AND'd, so if either side is null, result is null
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = CheckedArithmetic::<Sub, _>::checked_eval(vec, &5).unwrap();
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = CheckedArithmetic::<Mul, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_pvectors() {
        let left = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let vec = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4));
        let result = CheckedArithmetic::<Div, _>::checked_eval(vec, &10).unwrap();
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_overflow_returns_none() {
        let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2));
        let right = PVector::new(buffer![1u8, 50], Mask::new_true(2));

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_by_zero_returns_none() {
        let left = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));
        let right = PVector::new(buffer![2u32, 0, 3], Mask::new_true(3));

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &10).unwrap();

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }

    /// Exercises the allocating `&PVector` implementation directly and checks that the
    /// borrowed inputs are left untouched.
    #[test]
    fn test_add_borrowed_pvectors_leaves_inputs_untouched() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
        assert_eq!(left.elements(), &buffer![1u32, 2, 3, 4]);
        assert_eq!(right.elements(), &buffer![10u32, 20, 30, 40]);
    }

    /// Keeps a second handle on the left operand so `try_into_mut` cannot take unique ownership.
    /// This exercises the fallback (borrowing) branch and checks that the other handle does not
    /// observe the result.
    #[test]
    fn test_add_shared_pvector_does_not_mutate_other_handle() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let other_handle = left.clone();

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
        assert_eq!(other_handle.elements(), &buffer![1u32, 2, 3, 4]);
    }

    /// Nulls on either side must both be reflected in the result validity.
    #[test]
    fn test_validity_is_intersection_of_both_sides() {
        let left = PVector::new(
            buffer![1u32, 2, 3, 4],
            Mask::from_iter([true, false, true, true]),
        );
        let right = PVector::new(
            buffer![10u32, 20, 30, 40],
            Mask::from_iter([true, true, false, true]),
        );

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(
            result.validity(),
            &Mask::from_iter([true, false, false, true])
        );
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_sub_scalar_borrowed_preserves_validity() {
        let vec = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([true, false, true]));

        let result = CheckedArithmetic::<Sub, _>::checked_eval(&vec, &5u32).unwrap();
        assert_eq!(result.elements(), &buffer![5u32, 15, 25]);
        assert_eq!(result.validity(), vec.validity());
        assert_eq!(vec.elements(), &buffer![10u32, 20, 30]);
    }

    #[test]
    fn test_scalar_overflow_returns_none() {
        let vec = PVector::new(buffer![1u32, u32::MAX], Mask::new_true(2));
        let result = CheckedArithmetic::<Add, _>::checked_eval(vec, &1u32);
        assert!(result.is_none());
    }

    #[test]
    fn test_scalar_div_by_zero_returns_none() {
        let vec = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));
        let result = CheckedArithmetic::<Div, _>::checked_eval(vec, &0u32);
        assert!(result.is_none());
    }
}