vortex_compute/arithmetic/
buffer_checked.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6
7use crate::arithmetic::CheckedArithmetic;
8use crate::arithmetic::CheckedOperator;
9
10/// Implementation that attempts to downcast to a mutable buffer and operates in-place.
11impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for Buffer<T>
12where
13    T: Copy + num_traits::Zero,
14    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
15    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
16{
17    type Output = Buffer<T>;
18
19    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
20        match self.try_into_mut() {
21            Ok(lhs) => lhs.checked_eval(rhs),
22            Err(lhs) => (&lhs).checked_eval(rhs), // (&lhs) to delegate to borrowed impl
23        }
24    }
25}
26
27/// Implementation that operates in-place over a mutable buffer.
28impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for BufferMut<T>
29where
30    T: Copy + num_traits::Zero,
31    Op: CheckedOperator<T>,
32{
33    type Output = Buffer<T>;
34
35    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
36        assert_eq!(self.len(), rhs.len());
37
38        let mut i = 0;
39        let mut overflow = false;
40        let buffer = self
41            .map_each_in_place(|a| {
42                // SAFETY: lengths are equal, so index is in bounds
43                let b = unsafe { *rhs.get_unchecked(i) };
44                i += 1;
45
46                // On overflow, set flag and write zero
47                // We don't abort early because this code vectorizes better without the
48                // branch, and we expect overflow to be an exception rather than the norm.
49                Op::apply(&a, &b).unwrap_or_else(|| {
50                    overflow = true;
51                    T::zero()
52                })
53            })
54            .freeze();
55
56        (!overflow).then_some(buffer)
57    }
58}
59
60/// Implementation that allocates a new output buffer.
61impl<Op, T> CheckedArithmetic<Op> for &Buffer<T>
62where
63    T: Copy + num_traits::Zero,
64    Op: CheckedOperator<T>,
65{
66    type Output = Buffer<T>;
67
68    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
69        assert_eq!(self.len(), rhs.len());
70
71        let mut overflow = false;
72        let buffer =
73            Buffer::<T>::from_trusted_len_iter(self.iter().zip(rhs.iter()).map(|(a, b)| {
74                // On overflow, set flag and write zero
75                // We don't abort early because this code vectorizes better without the
76                // branch, and we expect overflow to be an exception rather than the norm.
77                Op::apply(a, b).unwrap_or_else(|| {
78                    overflow = true;
79                    T::zero()
80                })
81            }));
82        (!overflow).then_some(buffer)
83    }
84}
85
86/// Implementation that attempts to downcast to a mutable buffer and operates in-place against
87/// a scalar RHS value.
88impl<Op, T> CheckedArithmetic<Op, &T> for Buffer<T>
89where
90    T: Copy + num_traits::Zero,
91    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
92    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
93{
94    type Output = Buffer<T>;
95
96    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
97        match self.try_into_mut() {
98            Ok(lhs) => lhs.checked_eval(rhs),
99            Err(lhs) => (&lhs).checked_eval(rhs),
100        }
101    }
102}
103
104/// Implementation that operates in-place over a mutable buffer against a scalar RHS value.
105impl<Op, T> CheckedArithmetic<Op, &T> for BufferMut<T>
106where
107    T: Copy + num_traits::Zero,
108    Op: CheckedOperator<T>,
109{
110    type Output = Buffer<T>;
111
112    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
113        let mut overflow = false;
114        let buffer = self
115            .map_each_in_place(|a| {
116                Op::apply(&a, rhs).unwrap_or_else(|| {
117                    overflow = true;
118                    T::zero()
119                })
120            })
121            .freeze();
122
123        (!overflow).then_some(buffer)
124    }
125}
126
127/// Implementation that allocates a new output buffer operating against a scalar RHS value.
128impl<Op, T> CheckedArithmetic<Op, &T> for &Buffer<T>
129where
130    T: Copy + num_traits::Zero,
131    Op: CheckedOperator<T>,
132{
133    type Output = Buffer<T>;
134
135    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
136        let mut overflow = false;
137        let buffer = Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| {
138            Op::apply(a, rhs).unwrap_or_else(|| {
139                overflow = true;
140                T::zero()
141            })
142        }));
143
144        (!overflow).then_some(buffer)
145    }
146}
147
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::Add;
    use crate::arithmetic::CheckedArithmetic;
    use crate::arithmetic::Div;
    use crate::arithmetic::Mul;
    use crate::arithmetic::Sub;

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_scalar_overflow() {
        let buf = buffer![1u8, u8::MAX];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &1u8);
        assert!(result.is_none());
    }

    /// Keeping a second handle to the same memory should prevent `try_into_mut` from handing
    /// out mutable access (assuming clones share the allocation). This exercises the
    /// borrowed, allocating fallback path.
    #[test]
    fn test_add_buffers_shared_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        // The shared input must not have been mutated in place.
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_overflow_shared_lhs() {
        let left = buffer![u8::MAX, 100];
        let _alias = left.clone();
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_scalar_shared_lhs() {
        let buf = buffer![1u32, 2, 3, 4];
        let alias = buf.clone();

        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_borrowed_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_borrowed_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_borrowed_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
        assert_eq!(buf, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];
        let _ = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = CheckedArithmetic::<Sub, _>::checked_eval(buf, &5).unwrap();
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_sub_underflow() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_mul_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![2u8, 3];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_buffers() {
        let left = buffer![100u32, 200, 300, 400];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let buf = buffer![100u32, 200, 300, 400];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_by_zero() {
        let left = buffer![10u32, 20, 30];
        let right = buffer![2u32, 0, 3];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_scalar_by_zero() {
        let buf = buffer![10u32, 20, 30];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &0);
        assert!(result.is_none());
    }
}