vortex_compute/arithmetic/
buffer_checked.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{Buffer, BufferMut};
5
6use crate::arithmetic::{CheckedArithmetic, CheckedOperator};
7
8/// Implementation that attempts to downcast to a mutable buffer and operates in-place.
9impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for Buffer<T>
10where
11    T: Copy + num_traits::Zero,
12    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
13    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
14{
15    type Output = Buffer<T>;
16
17    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
18        match self.try_into_mut() {
19            Ok(lhs) => lhs.checked_eval(rhs),
20            Err(lhs) => (&lhs).checked_eval(rhs), // (&lhs) to delegate to borrowed impl
21        }
22    }
23}
24
25/// Implementation that operates in-place over a mutable buffer.
26impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for BufferMut<T>
27where
28    T: Copy + num_traits::Zero,
29    Op: CheckedOperator<T>,
30{
31    type Output = Buffer<T>;
32
33    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
34        assert_eq!(self.len(), rhs.len());
35
36        let mut i = 0;
37        let mut overflow = false;
38        let buffer = self
39            .map_each_in_place(|a| {
40                // SAFETY: lengths are equal, so index is in bounds
41                let b = unsafe { *rhs.get_unchecked(i) };
42                i += 1;
43
44                // On overflow, set flag and write zero
45                // We don't abort early because this code vectorizes better without the
46                // branch, and we expect overflow to be an exception rather than the norm.
47                Op::apply(&a, &b).unwrap_or_else(|| {
48                    overflow = true;
49                    T::zero()
50                })
51            })
52            .freeze();
53
54        (!overflow).then_some(buffer)
55    }
56}
57
58/// Implementation that allocates a new output buffer.
59impl<Op, T> CheckedArithmetic<Op> for &Buffer<T>
60where
61    T: Copy + num_traits::Zero,
62    Op: CheckedOperator<T>,
63{
64    type Output = Buffer<T>;
65
66    fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
67        assert_eq!(self.len(), rhs.len());
68
69        let mut overflow = false;
70        let buffer =
71            Buffer::<T>::from_trusted_len_iter(self.iter().zip(rhs.iter()).map(|(a, b)| {
72                // On overflow, set flag and write zero
73                // We don't abort early because this code vectorizes better without the
74                // branch, and we expect overflow to be an exception rather than the norm.
75                Op::apply(a, b).unwrap_or_else(|| {
76                    overflow = true;
77                    T::zero()
78                })
79            }));
80        (!overflow).then_some(buffer)
81    }
82}
83
84/// Implementation that attempts to downcast to a mutable buffer and operates in-place against
85/// a scalar RHS value.
86impl<Op, T> CheckedArithmetic<Op, &T> for Buffer<T>
87where
88    T: Copy + num_traits::Zero,
89    BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
90    for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
91{
92    type Output = Buffer<T>;
93
94    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
95        match self.try_into_mut() {
96            Ok(lhs) => lhs.checked_eval(rhs),
97            Err(lhs) => (&lhs).checked_eval(rhs),
98        }
99    }
100}
101
102/// Implementation that operates in-place over a mutable buffer against a scalar RHS value.
103impl<Op, T> CheckedArithmetic<Op, &T> for BufferMut<T>
104where
105    T: Copy + num_traits::Zero,
106    Op: CheckedOperator<T>,
107{
108    type Output = Buffer<T>;
109
110    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
111        let mut overflow = false;
112        let buffer = self
113            .map_each_in_place(|a| {
114                Op::apply(&a, rhs).unwrap_or_else(|| {
115                    overflow = true;
116                    T::zero()
117                })
118            })
119            .freeze();
120
121        (!overflow).then_some(buffer)
122    }
123}
124
125/// Implementation that allocates a new output buffer operating against a scalar RHS value.
126impl<Op, T> CheckedArithmetic<Op, &T> for &Buffer<T>
127where
128    T: Copy + num_traits::Zero,
129    Op: CheckedOperator<T>,
130{
131    type Output = Buffer<T>;
132
133    fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
134        let mut overflow = false;
135        let buffer = Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| {
136            Op::apply(a, rhs).unwrap_or_else(|| {
137                overflow = true;
138                T::zero()
139            })
140        }));
141
142        (!overflow).then_some(buffer)
143    }
144}
145
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub};

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_buffers_borrowed() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        // Exercises the allocating `&Buffer<T>` implementation directly.
        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        // The borrowed implementation must leave its input untouched.
        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_buffers_shared_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        // Keep a second handle alive so that, if clones share storage, `try_into_mut` cannot
        // claim the LHS and the borrowed fallback path is taken instead.
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        // The other handle must not observe any in-place mutation.
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_scalar_overflow() {
        let buf = buffer![1u8, u8::MAX];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &1);
        assert!(result.is_none());
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = CheckedArithmetic::<Sub, _>::checked_eval(buf, &5).unwrap();
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_sub_underflow() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_sub_underflow_borrowed() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_mul_scalar_borrowed() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
        // The borrowed implementation must leave its input untouched.
        assert_eq!(buf, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_mul_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![2u8, 3];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_buffers() {
        let left = buffer![100u32, 200, 300, 400];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let buf = buffer![100u32, 200, 300, 400];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_by_zero() {
        let left = buffer![10u32, 20, 30];
        let right = buffer![2u32, 0, 3];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_scalar_by_zero() {
        let buf = buffer![10u32, 20, 30];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &0);
        assert!(result.is_none());
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];

        // Both the in-place and the allocating buffer-RHS impls assert equal lengths.
        let _ = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
    }
}