vortex_compute/arithmetic/
buffer_checked.rs1use vortex_buffer::{Buffer, BufferMut};
5
6use crate::arithmetic::{CheckedArithmetic, CheckedOperator};
7
8impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for Buffer<T>
10where
11 T: Copy + num_traits::Zero,
12 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
13 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
14{
15 type Output = Buffer<T>;
16
17 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
18 match self.try_into_mut() {
19 Ok(lhs) => lhs.checked_eval(rhs),
20 Err(lhs) => (&lhs).checked_eval(rhs), }
22 }
23}
24
25impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for BufferMut<T>
27where
28 T: Copy + num_traits::Zero,
29 Op: CheckedOperator<T>,
30{
31 type Output = Buffer<T>;
32
33 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
34 assert_eq!(self.len(), rhs.len());
35
36 let mut i = 0;
37 let mut overflow = false;
38 let buffer = self
39 .map_each_in_place(|a| {
40 let b = unsafe { *rhs.get_unchecked(i) };
42 i += 1;
43
44 Op::apply(&a, &b).unwrap_or_else(|| {
48 overflow = true;
49 T::zero()
50 })
51 })
52 .freeze();
53
54 (!overflow).then_some(buffer)
55 }
56}
57
58impl<Op, T> CheckedArithmetic<Op> for &Buffer<T>
60where
61 T: Copy + num_traits::Zero,
62 Op: CheckedOperator<T>,
63{
64 type Output = Buffer<T>;
65
66 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
67 assert_eq!(self.len(), rhs.len());
68
69 let mut overflow = false;
70 let buffer =
71 Buffer::<T>::from_trusted_len_iter(self.iter().zip(rhs.iter()).map(|(a, b)| {
72 Op::apply(a, b).unwrap_or_else(|| {
76 overflow = true;
77 T::zero()
78 })
79 }));
80 (!overflow).then_some(buffer)
81 }
82}
83
84impl<Op, T> CheckedArithmetic<Op, &T> for Buffer<T>
87where
88 T: Copy + num_traits::Zero,
89 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
90 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
91{
92 type Output = Buffer<T>;
93
94 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
95 match self.try_into_mut() {
96 Ok(lhs) => lhs.checked_eval(rhs),
97 Err(lhs) => (&lhs).checked_eval(rhs),
98 }
99 }
100}
101
102impl<Op, T> CheckedArithmetic<Op, &T> for BufferMut<T>
104where
105 T: Copy + num_traits::Zero,
106 Op: CheckedOperator<T>,
107{
108 type Output = Buffer<T>;
109
110 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
111 let mut overflow = false;
112 let buffer = self
113 .map_each_in_place(|a| {
114 Op::apply(&a, rhs).unwrap_or_else(|| {
115 overflow = true;
116 T::zero()
117 })
118 })
119 .freeze();
120
121 (!overflow).then_some(buffer)
122 }
123}
124
125impl<Op, T> CheckedArithmetic<Op, &T> for &Buffer<T>
127where
128 T: Copy + num_traits::Zero,
129 Op: CheckedOperator<T>,
130{
131 type Output = Buffer<T>;
132
133 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
134 let mut overflow = false;
135 let buffer = Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| {
136 Op::apply(a, rhs).unwrap_or_else(|| {
137 overflow = true;
138 T::zero()
139 })
140 }));
141
142 (!overflow).then_some(buffer)
143 }
144}
145
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub};

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_scalar_overflow() {
        let buf = buffer![1u8, 2, u8::MAX];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &1u8);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_shared_lhs_is_not_mutated() {
        // A second handle keeps the storage shared, so the owned `Buffer` impl cannot
        // take its in-place path. The aliased data must come out unchanged.
        let left = buffer![1u32, 2, 3, 4];
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_borrowed_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        // Borrowed evaluation must leave the input untouched.
        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_borrowed_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_borrowed_scalar() {
        let buf = buffer![100u32, 200, 300];

        let result = CheckedArithmetic::<Div, _>::checked_eval(&buf, &10u32).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30]);

        let by_zero = CheckedArithmetic::<Div, _>::checked_eval(&buf, &0u32);
        assert!(by_zero.is_none());
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];

        let _ = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = CheckedArithmetic::<Sub, _>::checked_eval(buf, &5).unwrap();
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_sub_underflow() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_mul_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![2u8, 3];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_buffers() {
        let left = buffer![100u32, 200, 300, 400];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let buf = buffer![100u32, 200, 300, 400];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_by_zero() {
        let left = buffer![10u32, 20, 30];
        let right = buffer![2u32, 0, 3];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_scalar_by_zero() {
        let buf = buffer![10u32, 20, 30];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &0);
        assert!(result.is_none());
    }
}